Skip to content

Commit

Permalink
[ENH] Add base CortexSpatial model (#536)
Browse files Browse the repository at this point in the history
* [MNT] Update requirements.txt (#507)

* [DOC] Fix gallery thumbnail images (#510)

* [FIX] Add check for empty stimulus (#522)

* add check for empty stimulus in stim setter for implants

* add check for empty np.ndarray

* [FIX] Fix electrode numbering annotation in implant.plot() (#523)

* fix implant annotation

* used zorder

* [MNT][FIX] Remove outdated ubuntu workflows, add Jax version requirement (#529)

* Remove 18.04, add 20.04

* Add python version to display name

* remove python version

* add jax version

* Fix jax version

* test for wheels: remove jax version

* Add jax version back

* test for wheels: remove jax version

* Add jax version back

* test for wheels

* also skip win32 for python3.7

* add separate cortex spatial model

---------

Co-authored-by: isaac hoffman <trsileneh@gmail.com>
  • Loading branch information
jgranley and tallyhawley committed Mar 21, 2023
1 parent 88eb591 commit a062330
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 104 deletions.
3 changes: 2 additions & 1 deletion pulse2percept/models/cortex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
* :ref:`Basic Concepts > Computational Models <topics-models>`
"""
from .base import ScoreboardModel, ScoreboardSpatial
from .base import ScoreboardModel, ScoreboardSpatial, CortexSpatial


__all__ = [
'CortexSpatial',
'ScoreboardModel',
'ScoreboardSpatial'
]
199 changes: 96 additions & 103 deletions pulse2percept/models/cortex/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,104 @@
from ...topography import Polimeni2006Map
from .._beyeler2019 import fast_scoreboard
from ...utils.constants import ZORDER
import warnings
import numpy as np

class ScoreboardSpatial(SpatialModel):
class CortexSpatial(SpatialModel):
"""Abstract base class for cortical models
This is an abstract class that cortical models can subclass
to get cortical implementation of the following features.
1) Updated default parameters for cortex
2) Handling of multiple visual regions via regions property
3) Plotting, including multiple visual regions, legends, vertical
divide at longitudinal fissure, etc.
"""
@property
def regions(self):
return self._regions

@regions.setter
def regions(self, regions):

if not isinstance(regions, list):
regions = [regions]
self._regions = regions

def __init__(self, **params):
self._regions = None
super(CortexSpatial, self).__init__(**params)

# Use [Polemeni2006]_ visual field map by default
if 'retinotopy' not in params.keys():
self.retinotopy = Polimeni2006Map(regions=self.regions)
elif 'regions' in params.keys() and \
set(self.regions) != set(self.retinotopy.regions):
raise ValueError("Conflicting regions in provided retinotopy and regions")
else:
# need to override self.regions
self.regions = self.retinotopy.regions

if not isinstance(self.regions, list):
self.regions = [self.regions]

def get_default_params(self):
"""Returns all settable parameters of the scoreboard model"""
base_params = super(CortexSpatial, self).get_default_params()
params = {
'xrange' : (-5, 5),
'yrange' : (-5, 5),
'xystep' : 0.1,
# Visual field regions to simulate
'regions' : ['v1']
}
return {**base_params, **params}


def plot(self, use_dva=False, style=None, autoscale=True, ax=None,
figsize=None, fc=None):
"""Plot the model
Parameters
----------
use_dva : bool, optional
Plot points in visual field. If false, simulated points will be
plotted in cortex
style : {'hull', 'scatter', 'cell'}, optional
Grid plotting style:
* 'hull': Show the convex hull of the grid (that is, the outline of
the smallest convex set that contains all grid points).
* 'scatter': Scatter plot all grid points
* 'cell': Show the outline of each grid cell as a polygon. Note that
this can be costly for a high-resolution grid.
autoscale : bool, optional
Whether to adjust the x,y limits of the plot to fit the implant
ax : matplotlib.axes._subplots.AxesSubplot, optional
A Matplotlib axes object. If None, will either use the current axes
(if exists) or create a new Axes object.
figsize : (float, float), optional
Desired (width, height) of the figure in inches
Returns
-------
ax : ``matplotlib.axes.Axes``
Returns the axis object of the plot
"""
if style is None:
style = 'hull' if use_dva else 'scatter'
ax = self.grid.plot(style=style, use_dva=use_dva, autoscale=autoscale,
ax=ax, figsize=figsize, fc=fc,
zorder=ZORDER['background'],
legend=True if not use_dva else False)
if use_dva:
ax.set_xlabel('x (dva)')
ax.set_ylabel('y (dva)')
else:
ax.set_xticklabels(np.array(ax.get_xticks()) / 1000)
ax.set_yticklabels(np.array(ax.get_yticks()) / 1000)
ax.set_xlabel('x (mm)')
ax.set_ylabel('y (mm)')
return ax


class ScoreboardSpatial(CortexSpatial):
"""Cortical adaptation of scoreboard model from [Beyeler2019]_
Implements the scoreboard model described in [Beyeler2019]_, where percepts
Expand Down Expand Up @@ -69,76 +163,18 @@ class ScoreboardSpatial(SpatialModel):
``model.build()`` again for your changes to take effect.
"""
@property
def regions(self):
return self._regions

@regions.setter
def regions(self, regions):

if not isinstance(regions, list):
regions = [regions]
self._regions = regions

def __init__(self, **params):
self._regions = None
super(ScoreboardSpatial, self).__init__(**params)

# Use [Polemeni2006]_ visual field map by default
if 'retinotopy' not in params.keys():
self.retinotopy = Polimeni2006Map(regions=self.regions)
elif 'regions' in params.keys() and \
set(self.regions) != set(self.retinotopy.regions):
raise ValueError("Conflicting regions in provided retinotopy and regions")
else:
# need to override self.regions
self.regions = self.retinotopy.regions

if not isinstance(self.regions, list):
self.regions = [self.regions]

def get_default_params(self):
"""Returns all settable parameters of the scoreboard model"""
base_params = super(ScoreboardSpatial, self).get_default_params()
params = {
'xrange' : (-5, 5),
'yrange' : (-5, 5),
'xystep' : 0.1,
# radial current spread
'rho': 200,
# Visual field regions to simulate
'regions' : ['v1']
}
return {**base_params, **params}

def _build(self):
# warn the user either that they are simulating points at discontinuous boundaries,
# or that the points will be moved by a small constant
if np.any(self.grid['dva'].x == 0):
if hasattr(self.retinotopy, 'jitter_boundary') and self.retinotopy.jitter_boundary:
warnings.warn("Since the visual cortex is discontinuous " +
"across hemispheres, it is recommended to not simulate points " +
" at exactly x=0. Points on the boundary will be moved " +
"by a small constant")
else:
warnings.warn("Since the visual cortex is discontinuous " +
"across hemispheres, it is recommended to not simulate points " +
" at exactly x=0. This can be avoided by adding a small " +
"to both limits of xrange")
if (np.any([r in self.regions for r in self.grid.discontinuous_y]) and
np.any(self.grid['dva'].y == 0)):
if hasattr(self.retinotopy, 'jitter_boundary') and self.retinotopy.jitter_boundary:
warnings.warn("Since some simulated regions are discontinuous " +
"across the y axis, it is recommended to not simulate points " +
" at exactly y=0. Points on the boundary will be moved " +
"by a small constant")
else:
warnings.warn(f"Since some simulated regions are discontinuous " +
"across the y axis, it is recommended to not simulate points " +
" at exactly y=0. This can be avoided by adding a small " +
"to both limits of yrange or setting " +
"self.retinotopy.jitter_boundary=True")

def _predict_spatial(self, earray, stim):
"""Predicts the brightness at spatial locations"""
x_el = np.array([earray[e].x for e in stim.electrodes],
Expand All @@ -160,49 +196,6 @@ def _predict_spatial(self, earray, stim):
self.n_threads)
for region in self.regions ],
axis = 0)

def plot(self, use_dva=False, style=None, autoscale=True, ax=None,
figsize=None, fc=None):
"""Plot the model
Parameters
----------
use_dva : bool, optional
Plot points in visual field. If false, simulated points will be
plotted in cortex
style : {'hull', 'scatter', 'cell'}, optional
Grid plotting style:
* 'hull': Show the convex hull of the grid (that is, the outline of
the smallest convex set that contains all grid points).
* 'scatter': Scatter plot all grid points
* 'cell': Show the outline of each grid cell as a polygon. Note that
this can be costly for a high-resolution grid.
autoscale : bool, optional
Whether to adjust the x,y limits of the plot to fit the implant
ax : matplotlib.axes._subplots.AxesSubplot, optional
A Matplotlib axes object. If None, will either use the current axes
(if exists) or create a new Axes object.
figsize : (float, float), optional
Desired (width, height) of the figure in inches
Returns
-------
ax : ``matplotlib.axes.Axes``
Returns the axis object of the plot
"""
if style is None:
style = 'hull' if use_dva else 'scatter'
ax = self.grid.plot(style=style, use_dva=use_dva, autoscale=autoscale,
ax=ax, figsize=figsize, fc=fc,
zorder=ZORDER['background'],
legend=True if not use_dva else False)
if use_dva:
ax.set_xlabel('x (dva)')
ax.set_ylabel('y (dva)')
else:
ax.set_xticklabels(np.array(ax.get_xticks()) / 1000)
ax.set_yticklabels(np.array(ax.get_yticks()) / 1000)
ax.set_xlabel('x (mm)')
ax.set_ylabel('y (mm)')
return ax


class ScoreboardModel(Model):
Expand Down

0 comments on commit a062330

Please sign in to comment.