Skip to content

Commit

Permalink
[ENH] Add methods to initialize an EnsembleImplant from a list of coo…
Browse files Browse the repository at this point in the history
…rdinates (#601)

* add from_coords method to initialize ensemble

* add typecheck for prosthesissystem

* update ensemble from_coords, add from_cortical_map to match from_neuropythy
  • Loading branch information
tallyhawley committed Jan 26, 2024
1 parent 6be2f3f commit ebb1d33
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 1 deletion.
102 changes: 102 additions & 0 deletions pulse2percept/implants/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""`EnsembleImplant`"""
import numpy as np
from .base import ProsthesisSystem
from .electrodes import Electrode
from .electrode_arrays import ElectrodeArray
Expand All @@ -9,6 +10,107 @@ class EnsembleImplant(ProsthesisSystem):
# Frozen class: User cannot add more class attributes
__slots__ = ('_implants', '_earray', '_stim', 'safe_mode', 'preprocess')

@classmethod
def from_cortical_map(cls, implant_type, vfmap, locs=None, xrange=None, yrange=None, xystep=None,
region='v1'):
"""
Create an ensemble implant from a cortical visual field map.
The implant will be created by creating an implant of type `implant_type`
for each visual field location specified either by locs or by xrange, yrange,
and xystep. Each implant will be centered at the given location.
Parameters
----------
vfmap : p2p.topography.CorticalMap
Visual field map to create implant from.
implant_type : type
Type of implant to create for the ensemble. Must subclass
p2p.implants.ProsthesisSystem
locs : np.ndarray with shape (n, 2), optional
Array of visual field locations to create implants at (dva).
Not needed if using xrange, yrange, and xystep.
xrange, yrange: tuple of floats, optional
Range of x and y coordinates (dva) to create implants at.
xystep : float, optional
Spacing between implant centers.
region : str, optional
Region of cortex to create implant in.
Returns
-------
ensemble : p2p.implants.EnsembleImplant
Ensemble implant created from the cortical visual field map.
"""
from ..topography import CorticalMap, Grid2D
if not isinstance(vfmap, CorticalMap):
raise TypeError("vfmap must be a p2p.topography.CorticalMap")
if not issubclass(implant_type, ProsthesisSystem):
raise TypeError("implant_type must be a sub-type of ProsthesisSystem")

if locs is None:
if xrange is None:
xrange = (-3, 3)
if yrange is None:
yrange = (-3, 3)
if xystep is None:
xystep = 1

# make a grid of points
grid = Grid2D(xrange, yrange, xystep)
xlocs = grid.x.flatten()
ylocs = grid.y.flatten()
else:
xlocs = locs[:, 0]
ylocs = locs[:, 1]

implant_locations = np.array(vfmap.from_dva()[region](xlocs, ylocs)).T

return cls.from_coords(implant_type=implant_type, locs=implant_locations)


@classmethod
def from_coords(cls, implant_type, locs=None, xrange=None, yrange=None, xystep=None):
"""
Create an ensemble implant using physical (cortical or retinal) coordinates.
Parameters
----------
implant_type : type
The type of implant to create for the ensemble.
locs : np.ndarray with shape (n, 2), optional
Array of physical locations (um) to create implants at. Not
needed if using xrange, yrange, and xystep.
xrange, yrange: tuple of floats, optional
Range of x and y coordinates to create implants at.
xystep : float, optional
Spacing between implant centers.
"""
from ..topography import Grid2D

if not issubclass(implant_type, ProsthesisSystem):
raise TypeError("implant_type must be a sub-type of ProsthesisSystem")

if locs is None:
if xrange is None:
xrange = (-3, 3)
if yrange is None:
yrange = (-3, 3)
if xystep is None:
xystep = 1

# make a grid of points
grid = Grid2D(xrange, yrange, xystep)
xlocs = grid.x.flatten()
ylocs = grid.y.flatten()
else:
xlocs = locs[:, 0]
ylocs = locs[:, 1]

implant_list = [implant_type(x=x, y=y) for x,y in zip(xlocs, ylocs)]

return cls(implant_list)

def __init__(self, implants, stim=None, preprocess=False,safe_mode=False):
"""Ensemble implant
Expand Down
58 changes: 57 additions & 1 deletion pulse2percept/implants/tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pulse2percept.implants import (EnsembleImplant, PointSource, ProsthesisSystem)
from pulse2percept.implants.cortex import Cortivis
from pulse2percept.topography import Polimeni2006Map
from pulse2percept.models.cortex.base import ScoreboardModel

def test_EnsembleImplant():
Expand Down Expand Up @@ -50,4 +51,59 @@ def test_ensemble_cortivis():
npt.assert_equal(ensemble['0-1'].x, cortivis0['1'].x)
npt.assert_equal(ensemble['0-1'].y, cortivis0['1'].y)
npt.assert_equal(ensemble['1-1'].x, cortivis1['1'].x)
npt.assert_equal(ensemble['1-1'].y, cortivis1['1'].y)
npt.assert_equal(ensemble['1-1'].y, cortivis1['1'].y)

# test from_coords initialization (physical coords in um)
def test_from_coords():
locs = np.array([(0,0), (10000,0)])

# check invalid instantiations
with pytest.raises(TypeError):
EnsembleImplant.from_coords(Cortivis(0), locs=locs)

locs = np.array([(0,0), (10000,0), (0, 10000)])

c0 = Cortivis(x=0,y=0)
c1 = Cortivis(x=10000,y=0)
c2 = Cortivis(x=0, y=10000)
ensemble = EnsembleImplant.from_coords(Cortivis, locs=locs)

# check that positions are the same
npt.assert_equal(ensemble['0-1'].x, c0['1'].x)
npt.assert_equal(ensemble['0-1'].y, c0['1'].y)
npt.assert_equal(ensemble['0-1'].z, c0['1'].z)
npt.assert_equal(ensemble['1-1'].x, c1['1'].x)
npt.assert_equal(ensemble['1-1'].y, c1['1'].y)
npt.assert_equal(ensemble['1-1'].z, c1['1'].z)
npt.assert_equal(ensemble['2-1'].x, c2['1'].x)
npt.assert_equal(ensemble['2-1'].y, c2['1'].y)
npt.assert_equal(ensemble['2-1'].z, c2['1'].z)

# test from_cortical_map initialization (vf coords in dva)
def test_from_cortical_map():
vfmap = Polimeni2006Map()

locs = np.array([(2000,2000), (10000,0), (5000, 5000)]).astype(np.float64)

# find locations in dva
dva_x, dva_y = vfmap.to_dva()['v1'](locs[:,0], locs[:,1])
dva_list = [(x,y) for x,y in zip(dva_x, dva_y)]
dva_locs = np.array(dva_list)

c0 = Cortivis(x=2000, y=2000)
c1 = Cortivis(x=10000, y=0)
c2 = Cortivis(x=5000, y=5000)

# use dva coords to create ensemble
ensemble = EnsembleImplant.from_cortical_map(Cortivis, vfmap, dva_locs)

# check that positions are approx. the same
npt.assert_approx_equal(ensemble['0-1'].x, c0['1'].x, 5)
npt.assert_approx_equal(ensemble['0-1'].y, c0['1'].y, 5)
npt.assert_approx_equal(ensemble['0-1'].z, c0['1'].z, 5)
npt.assert_approx_equal(ensemble['1-1'].x, c1['1'].x, 5)
npt.assert_approx_equal(ensemble['1-1'].y, c1['1'].y, 5)
npt.assert_approx_equal(ensemble['1-1'].z, c1['1'].z, 5)
npt.assert_approx_equal(ensemble['2-1'].x, c2['1'].x, 5)
npt.assert_approx_equal(ensemble['2-1'].y, c2['1'].y, 5)
npt.assert_approx_equal(ensemble['2-1'].z, c2['1'].z, 5)

0 comments on commit ebb1d33

Please sign in to comment.