Skip to content

Commit

Permalink
[ENH] Add Ensemble implant (#537)
Browse files Browse the repository at this point in the history
* add EnsembleImplant

* add predict_percept smoke test for ensemble

* add invalid instantiation tests

* Update ensemble.py

---------

Co-authored-by: Jacob Granley <jgranley@ucsb.edu>
  • Loading branch information
tallyhawley and jgranley committed Apr 10, 2023
1 parent 985d539 commit d2f0346
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pulse2percept/implants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
bvt
imie
prima
ensemble
.. seealso::
Expand All @@ -28,6 +29,7 @@
from .bvt import BVT24, BVT44
from .prima import PhotovoltaicPixel, PRIMA, PRIMA75, PRIMA55, PRIMA40
from .imie import IMIE
from .ensemble import EnsembleImplant
from . import cortex

__all__ = [
Expand All @@ -42,6 +44,7 @@
'Electrode',
'ElectrodeArray',
'ElectrodeGrid',
'EnsembleImplant',
'HexElectrode',
'PhotovoltaicPixel',
'PointSource',
Expand Down
68 changes: 68 additions & 0 deletions pulse2percept/implants/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""`EnsembleImplant`"""
from .base import ProsthesisSystem
from .electrodes import Electrode
from .electrode_arrays import ElectrodeArray

class EnsembleImplant(ProsthesisSystem):
"""Ensemble implant
An ensemble implant combines multiple implants into one larger electrode array
for the purpose of modeling tandem implants, e.g. CORTIVIS, ICVP
Parameters
----------
implants : list or dict
A list or dict of implants to be combined.
stim : :py:class:`~pulse2percept.stimuli.Stimulus` source type
A valid source type for the :py:class:`~pulse2percept.stimuli.Stimulus`
object (e.g., scalar, NumPy array, pulse train).
preprocess : bool or callable, optional
Either True/False to indicate whether to execute the implant's default
preprocessing method whenever a new stimulus is assigned, or a custom
function (callable).
safe_mode : bool, optional
If safe mode is enabled, only charge-balanced stimuli are allowed.
"""

# Frozen class: User cannot add more class attributes
__slots__ = ('_implants', '_earray', '_stim', 'safe_mode', 'preprocess')

def __init__(self, implants, stim=None, preprocess=False,safe_mode=False):
self.implants = implants
self.safe_mode = safe_mode
self.preprocess = preprocess
self.stim = stim

def _pprint_params(self):
"""Return dict of class attributes to pretty-print"""
return {'implants': self.implants, 'earray': self.earray, 'stim': self.stim,
'safe_mode': self.safe_mode, 'preprocess': self.preprocess}

@property
def implants(self):
"""Dict of implants
"""
return self._implants

@implants.setter
def implants(self, implants):
"""Implant dict setter (called upon ``self.implants = implants``)"""
# Assign the implant dict:
if isinstance(implants, list):
if not all(isinstance(implant, ProsthesisSystem) for implant in implants):
raise TypeError(f"All elements in 'implants' must be ProsthesisSystem objects.")
self._implants = {i:implant for i,implant in enumerate(implants)}
elif isinstance(implants, dict):
if not all(isinstance(implant, ProsthesisSystem) for implant in implants.values()):
raise TypeError(f"All elements in 'implants' must be ProsthesisSystem objects.")
self._implants = implants.copy()
else:
raise TypeError(f"'implants' must be a list or a dict object, not "
f"{type(implants)}.")
# Create the electrode array
electrodes = {}
for i, implant in self._implants.items():
for name, electrode in implant.earray.electrodes.items():
electrodes[str(i) + "-" + str(name)] = electrode
self._earray = ElectrodeArray(electrodes)
53 changes: 53 additions & 0 deletions pulse2percept/implants/tests/test_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import numpy.testing as npt
import pytest
from pulse2percept.implants import (EnsembleImplant, PointSource, ProsthesisSystem)
from pulse2percept.implants.cortex import Cortivis
from pulse2percept.models.cortex.base import ScoreboardModel

def test_EnsembleImplant():
# Invalid instantiations:
with pytest.raises(TypeError):
EnsembleImplant(implants="this can't happen")
with pytest.raises(TypeError):
EnsembleImplant(implants=[3,Cortivis()])
with pytest.raises(TypeError):
EnsembleImplant(implants={'1': Cortivis(), '2': 'abcd'})

# Instantiate with list
p1 = ProsthesisSystem(PointSource(0,0,0))
p2 = ProsthesisSystem(PointSource(1,1,1))
ensemble = EnsembleImplant(implants=[p1,p2])
npt.assert_equal(ensemble.n_electrodes, 2)
npt.assert_equal(ensemble[0], p1[0])
npt.assert_equal(ensemble[1], p2[0])
npt.assert_equal(ensemble.electrode_names, ['0-0','1-0'])

# Instantiate with dict
ensemble = EnsembleImplant(implants={'A': p2, 'B': p1})
npt.assert_equal(ensemble.n_electrodes, 2)
npt.assert_equal(ensemble[0], p2[0])
npt.assert_equal(ensemble[1], p1[0])
npt.assert_equal(ensemble.electrode_names, ['A-0','B-0'])

# predict_percept smoke test
ensemble.stim = [1,1]
model = ScoreboardModel().build()
model.predict_percept(ensemble)

# we essentially just need to make sure that electrode names are
# set properly, the rest of the EnsembleImplant functionality
# (electrode placement, etc) is determined by the implants passed in
# and thus already tested
# but we'll test it again just to make sure
def test_ensemble_cortivis():
cortivis0 = Cortivis(0)
cortivis1 = Cortivis(x=10000)

ensemble = EnsembleImplant([cortivis0, cortivis1])

# check that positions are the same
npt.assert_equal(ensemble['0-1'].x, cortivis0['1'].x)
npt.assert_equal(ensemble['0-1'].y, cortivis0['1'].y)
npt.assert_equal(ensemble['1-1'].x, cortivis1['1'].x)
npt.assert_equal(ensemble['1-1'].y, cortivis1['1'].y)

0 comments on commit d2f0346

Please sign in to comment.