Skip to content

Commit

Permalink
[ENH] Update grid.plot for cortical transforms (#534)
Browse files Browse the repository at this point in the history
* grid interface changes

* changed valueerror to runtime error in percept save test

* Fix unique time point bug

* temp commit to store grid stuff

* [MNT] Update requirements.txt (#507)

* [DOC] Fix gallery thumbnail images (#510)

* [FIX] Add check for empty stimulus (#522)

* add check for empty stimulus in stim setter for implants

* add check for empty np.ndarray

* [FIX] Fix electrode numbering annotation in implant.plot() (#523)

* fix implant annotation

* used zorder

* [ENH][REF] Modify Grid2D and VisualFieldMap to support multiple visual areas (#509)

* grid interface changes

* changed valueerror to runtime error in percept save test

* Fix unique time point bug

* temp commit to store grid stuff

* update requirements

* update base.py for new grid class

* Grid class now supports multiple layers

* refactor layer to be region

* Add region_mappings, RetinalMap

* Fixed overwriting static attributes

* Base class for cortical models

* [MNT] Update requirements.txt (#507)

* Add tests, made inv transforms optional to overwrite

* update with named tuple coordinate grids, and ret_to_dva etc
"

* refactor everything to ret_to_dva

* Update static ret2dva references

* removed backwards compatibility for ret2dva

* update doc

* [REF] Add topography, implants.cortex, and models.cortex submodules (#518)

* [MNT] Update requirements.txt (#507)

* [DOC] Fix gallery thumbnail images (#510)

* add topography module

* fix imports

* refactor tests for topography module

* doc

* add epmty submodules for implants.cortex and models.cortex

* add orion implant and test

* update cortex __init__.py

* add orion implant and test

* add scoreboard cortex

* Replace outdated references to utils.*Map

* fix unrelated jax error: require older jax

* add doc and skeleton tests

* docstrings

* req

* add cortex model, changes to polimeni map

* bug fixes

* add some initial tests

* add equality check

* visual field maps inherit from basemodel

* add offset to polimeni, allow polimeni to cross x axis

* fix inverse transforms

* finished tests

* update topography tests for float32

* make test simpler

* test w print for mx

* test w print for mx

* more tests for mac

* more tests for mac

* improve numerical stability

* improve numerical stability test

* remove debug, add comments

* loop through transforms in grid2d plot

* change grid plot to use retinotopy

* add documentation for use_dva param in grid2d plot

* fix style=hull plot bug, changed default colors

* change default grid plot color to gray, fix hull style when using use_dva=True, make discontinuity check cleaner

* update default plotting, tests

---------

Co-authored-by: jgranley <jgranley@ucsb.edu>
Co-authored-by: isaac hoffman <trsileneh@gmail.com>
  • Loading branch information
3 people committed Mar 21, 2023
1 parent 5cce8c5 commit 88eb591
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 117 deletions.
11 changes: 6 additions & 5 deletions pulse2percept/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,15 +449,16 @@ def plot(self, use_dva=False, style='hull', autoscale=True, ax=None,
"""
if not self.is_built:
self.build()

zorder = ZORDER['background'] + (0 if use_dva else 1)

ax = self.grid.plot(autoscale=autoscale, ax=ax, style=style, zorder=zorder,
figsize=figsize, use_dva=use_dva)

if use_dva:
ax = self.grid.plot(autoscale=autoscale, ax=ax, style=style,
zorder=ZORDER['background'], figsize=figsize)
ax.set_xlabel('x (dva)')
ax.set_ylabel('y (dva)')
else:
ax = self.grid.plot(transform=self.retinotopy.dva_to_ret, ax=ax,
zorder=ZORDER['background'] + 1, style=style,
figsize=figsize, autoscale=autoscale)
ax.set_xlabel('x (microns)')
ax.set_ylabel('y (microns)')
return ax
Expand Down
6 changes: 2 additions & 4 deletions pulse2percept/models/beyeler2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,6 @@ def plot(self, use_dva=False, style='hull', annotate=True, autoscale=True,
od_xy = self.loc_od
od_w = 6.44
od_h = 6.85
grid_transform = None
# Flip y upside down for dva:
axon_bundles = [np.array(self.retinotopy.ret_to_dva(bundle[:, 0],
-bundle[:, 1])).T
Expand All @@ -809,7 +808,6 @@ def plot(self, use_dva=False, style='hull', annotate=True, autoscale=True,
od_xy = self.retinotopy.dva_to_ret(*self.loc_od)
od_w = 1770
od_h = 1880
grid_transform = self.retinotopy.dva_to_ret
if self.eye == 'RE':
labels = ['superior', 'inferior', 'temporal', 'nasal']
else:
Expand All @@ -830,8 +828,8 @@ def plot(self, use_dva=False, style='hull', annotate=True, autoscale=True,
color='white', zorder=ZORDER['background'] + 1))
# Show extent of simulated grid:
if self.is_built:
self.grid.plot(ax=ax, transform=grid_transform, style=style,
zorder=ZORDER['background'] + 2)
self.grid.plot(ax=ax, style=style, zorder=ZORDER['background'] + 2,
use_dva=use_dva)
ax.set_xlabel(f'x ({units})')
ax.set_ylabel(f'y ({units})')
if autoscale:
Expand Down
38 changes: 11 additions & 27 deletions pulse2percept/models/cortex/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,20 +160,17 @@ def _predict_spatial(self, earray, stim):
self.n_threads)
for region in self.regions ],
axis = 0)


def plot(self, use_dva=False, style='scatter', autoscale=True, ax=None,
figsize=None):

def plot(self, use_dva=False, style=None, autoscale=True, ax=None,
figsize=None, fc=None):
"""Plot the model
Parameters
----------
use_dva : bool, optional
Uses degrees of visual angle (dva) if True, else retinal
coordinates (microns)
Plot points in visual field. If false, simulated points will be
plotted in cortex
style : {'hull', 'scatter', 'cell'}, optional
Grid plotting style:
* 'hull': Show the convex hull of the grid (that is, the outline of
the smallest convex set that contains all grid points).
* 'scatter': Scatter plot all grid points
Expand All @@ -186,34 +183,21 @@ def plot(self, use_dva=False, style='scatter', autoscale=True, ax=None,
(if exists) or create a new Axes object.
figsize : (float, float), optional
Desired (width, height) of the figure in inches
Returns
-------
ax : ``matplotlib.axes.Axes``
Returns the axis object of the plot
"""
if not self.is_built:
self.build()
if style is None:
style = 'hull' if use_dva else 'scatter'
ax = self.grid.plot(style=style, use_dva=use_dva, autoscale=autoscale,
ax=ax, figsize=figsize, fc=fc,
zorder=ZORDER['background'],
legend=True if not use_dva else False)
if use_dva:
ax = self.grid.plot(autoscale=autoscale, ax=ax, style=style,
zorder=ZORDER['background'], figsize=figsize)
ax.set_xlabel('x (dva)')
ax.set_ylabel('y (dva)')
else:
for idx_region, region in enumerate(self.retinotopy.regions):
transform = self.retinotopy.from_dva()[region]
if region == 'v1':
fc = 'red'
elif region == 'v2':
fc = 'orange'
elif region == 'v3':
fc = 'green'
ax = self.grid.plot(transform=transform, label=region, ax=ax,
zorder=ZORDER['background'] + 1, style=style,
figsize=figsize, autoscale=autoscale, fc=fc)


ax.legend(loc='upper right')
ax.set_xticklabels(np.array(ax.get_xticks()) / 1000)
ax.set_yticklabels(np.array(ax.get_yticks()) / 1000)
ax.set_xlabel('x (mm)')
Expand Down
172 changes: 100 additions & 72 deletions pulse2percept/topography/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
import matplotlib as mpl


from ..utils.base import PrettyPrint
from ..utils.constants import ZORDER
Expand Down Expand Up @@ -120,6 +122,7 @@ def __init__(self, x_range, y_range, step=1, grid_type='rectangular'):
self.y_range = y_range
self.step = step
self.type = grid_type
self.retinotopy = None
self.regions = []
self.retinotopy = None
# Datatype for storing the grid of coordinates
Expand Down Expand Up @@ -214,20 +217,12 @@ def build(self, retinotopy):
if region not in self.all_regions:
self._register_regions([region])

def plot(self, transform=None, label=None, style='hull', autoscale=True,
zorder=None, ax=None, figsize=None, fc='gray'):
def plot(self, style='hull', autoscale=True, zorder=None, ax=None,
figsize=None, fc=None, use_dva=False, legend=False):
"""Plot the extension of the grid
Parameters
----------
transform : function, optional
A coordinate transform to be applied to the (x,y) coordinates of
the grid (e.g., :py:meth:`Curcio1990Transform.dva_to_ret`). It must
accept two input arguments (x and y) and output two variables (the
transformed x and y).
label : str, optional
A name to be used as the label of the matplotlib plot. This can be used
to label plots with multiple regions (i.e. call plt.legend after)
style : {'hull', 'scatter', 'cell'}, optional
* 'hull': Show the convex hull of the grid (that is, the outline of
the smallest convex set that contains all grid points).
Expand All @@ -246,6 +241,13 @@ def plot(self, transform=None, label=None, style='hull', autoscale=True,
fc : str or valid matplotlib color, optional
Facecolor, or edge color if style=scatter, of the plotted region
Defaults to gray
use_dva : bool, optional
Whether dva or transformed points should be plotted. If True, will
not apply any transformations, and if False, will apply all
transformations in self.retinotopy
legend : bool, optional
Whether to add a plot legend. The legend is always added if there
are 2 or more regions. This only applies if there is 1 region.
"""
if style.lower() not in ['hull', 'scatter', 'cell']:
raise ValueError(f'Unknown plotting style "{style}". Choose from: '
Expand All @@ -268,79 +270,105 @@ def plot(self, transform=None, label=None, style='hull', autoscale=True,
x_step = self.step
y_step = self.step

if style.lower() == 'cell':
# Show a polygon for every grid cell that we are simulating:
if self.type == 'hexagonal':
raise NotImplementedError
patches = []
for xret, yret in zip(x.ravel(), y.ravel()):
# Outlines of the cell are given by (x,y) and the step size:
vertices = np.array([
[xret - x_step / 2, yret - y_step / 2],
[xret - x_step / 2, yret + y_step / 2],
[xret + x_step / 2, yret + y_step / 2],
[xret + x_step / 2, yret - y_step / 2],
])
# Check for discontinuity.
# This is super hacky, but different regions need to be plotted
# differently, and it can't be implemented from outside this fn
# because it depends not only on retinotopy, but also transform.
# If region is discontinuous and vertices cross boundary, skip
# TODO Luke: This can be modified to no longer depend on the
# transform name when the transform parameter
# is no longer passed (because we'll know region then)
if (transform and
np.any([r in transform.__name__ for r in self.discontinuous_x]) and
np.sign(vertices[0][0]) != np.sign(vertices[2][0])):
continue
if (transform and
np.any([r in transform.__name__ for r in self.discontinuous_y]) and
np.sign(vertices[0][1]) != np.sign(vertices[1][1])):
continue
transforms = [('dva', None)]
if not use_dva:
transforms = self.retinotopy.from_dva().items()

color_map = {
'ret' : 'gray',
'dva' : 'gray',
'v1' : 'red',
'v2' : 'orange',
'v3' : 'green'
}
# for tracking legend items when style='cell'
legends = []
for idx, (label, transform) in enumerate(transforms):
if fc is not None:
color = fc[label] if isinstance(fc, dict) else fc
elif label in color_map.keys():
color = color_map[label]
else:
color = 'gray'

if style.lower() == 'cell':
# Show a polygon for every grid cell that we are simulating:
if self.type == 'hexagonal':
raise NotImplementedError
patches = []
for xret, yret in zip(x.ravel(), y.ravel()):
# Outlines of the cell are given by (x,y) and the step size:
vertices = np.array([
[xret - x_step / 2, yret - y_step / 2],
[xret - x_step / 2, yret + y_step / 2],
[xret + x_step / 2, yret + y_step / 2],
[xret + x_step / 2, yret - y_step / 2],
])
# If region is discontinuous and vertices cross boundary, skip
if (transform and
label in self.discontinuous_x and
np.sign(vertices[0][0]) != np.sign(vertices[2][0])):
continue
if (transform and
label in self.discontinuous_y and
np.sign(vertices[0][1]) != np.sign(vertices[1][1])):
continue
# transform the points
if transform is not None:
vertices = np.array(transform(*vertices.T)).T
patches.append(Polygon(vertices, alpha=0.3, ec='k', fc=color,
ls='--', zorder=zorder, label=label))
legends.append(patches[0])
ax.add_collection(PatchCollection(patches, match_original=True,
zorder=zorder, label=label))
else:
# Show either the convex hull or a scatter plot:
if transform is not None:
vertices = np.array(transform(*vertices.T)).T
patches.append(Polygon(vertices, alpha=0.3, ec='k', fc=fc,
ls='--', zorder=zorder))
ax.add_collection(PatchCollection(patches, match_original=True,
zorder=zorder, label=label))
else:
# Show either the convex hull or a scatter plot:
if transform is not None:
x, y = transform(self.x, self.y)
points = np.vstack((x.ravel(), y.ravel()))
# Remove NaN values from the grid:
points = points[:, ~np.logical_or(*np.isnan(points))]
if style.lower() == 'hull':
if self.retinotopy and self.retinotopy.split_map:
points_right = points[:, points[0] >= 0]
points_left = points[:, points[0] <= 0]
hull_right = ConvexHull(points_right.T)
hull_left = ConvexHull(points_left.T)
ax.add_patch(Polygon(points_right[:, hull_right.vertices].T, alpha=0.3, ec='k',
fc=fc, ls='--', zorder=zorder, label=label))
ax.add_patch(Polygon(points_left[:, hull_left.vertices].T, alpha=0.3, ec='k',
fc=fc, ls='--', zorder=zorder))
else:
hull = ConvexHull(points.T)
ax.add_patch(Polygon(points[:, hull.vertices].T, alpha=0.3, ec='k',
fc=fc, ls='--', zorder=zorder, label=label))

elif style.lower() == 'scatter':
ax.scatter(*points, alpha=0.3, ec=fc, color=fc, marker='+',
zorder=zorder, label=label)

x, y = transform(self.x, self.y)
points = np.vstack((x.ravel(), y.ravel()))
# Remove NaN values from the grid:
points = points[:, ~np.logical_or(*np.isnan(points))]
if style.lower() == 'hull':
if self.retinotopy and self.retinotopy.split_map and not use_dva:
# all split maps have an offset for left fovea
divide = 0 if use_dva else self.retinotopy.left_offset / 2
points_right = points[:, points[0] >= divide]
points_left = points[:, points[0] <= divide]
if points_right.size > 0:
hull_right = ConvexHull(points_right.T)
ax.add_patch(Polygon(points_right[:, hull_right.vertices].T, alpha=0.3, ec='k',
fc=color, ls='--', zorder=zorder))
if points_left.size > 0:
hull_left = ConvexHull(points_left.T)
ax.add_patch(Polygon(points_left[:, hull_left.vertices].T, alpha=0.3, ec='k',
fc=color, ls='--', zorder=zorder))
else:
hull = ConvexHull(points.T)
ax.add_patch(Polygon(points[:, hull.vertices].T, alpha=0.3, ec='k',
fc=color, ls='--', zorder=zorder))
legends.append(ax.patches[-1])
elif style.lower() == 'scatter':
ax.scatter(*points, alpha=0.4, ec=color, color=color, marker='+',
zorder=zorder, label=label)

# This is needed in MPL 3.0.X to set the axis limit correctly:
ax.autoscale_view()

# plot boundary between hemispheres if it exists
# but don't change the plot limits
lim = ax.get_xlim()
if self.retinotopy and self.retinotopy.split_map and hasattr(self.retinotopy, 'left_offset'):
if self.retinotopy and self.retinotopy.split_map:
boundary = self.retinotopy.left_offset / 2
if use_dva:
boundary = 0
ax.axvline(boundary, linestyle=':', c='gray')
ax.set_xlim(lim)

if len(transforms) > 1 or legend:
if style in ['cell', 'hull']:
ax.legend(legends, [t[0] for t in transforms], loc='upper right')
else:
ax.legend(loc='upper right')

return ax


Expand Down
3 changes: 3 additions & 0 deletions pulse2percept/topography/cortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
class CorticalMap(VisualFieldMap):
"""Template class for V1/V2/V3 visuotopic maps"""
allowed_regions = {'v1', 'v2', 'v3'}

# All cortical maps are split into 2 hemispheres
split_map = True

def __init__(self, **params):
Expand Down Expand Up @@ -47,6 +49,7 @@ def to_dva(self):
def get_default_params(self):
params = {
'regions' : ['v1'],
# Offset for the left hemisphere fovea
'left_offset' : -20000
}
return params
Expand Down

0 comments on commit 88eb591

Please sign in to comment.