-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add EnsembleImplant * add predict_percept smoke test for ensemble * add invalid instantiation tests * Update ensemble.py --------- Co-authored-by: Jacob Granley <jgranley@ucsb.edu>
- Loading branch information
1 parent
985d539
commit d2f0346
Showing
3 changed files
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""`EnsembleImplant`""" | ||
from .base import ProsthesisSystem | ||
from .electrodes import Electrode | ||
from .electrode_arrays import ElectrodeArray | ||
|
||
class EnsembleImplant(ProsthesisSystem): | ||
"""Ensemble implant | ||
An ensemble implant combines multiple implants into one larger electrode array | ||
for the purpose of modeling tandem implants, e.g. CORTIVIS, ICVP | ||
Parameters | ||
---------- | ||
implants : list or dict | ||
A list or dict of implants to be combined. | ||
stim : :py:class:`~pulse2percept.stimuli.Stimulus` source type | ||
A valid source type for the :py:class:`~pulse2percept.stimuli.Stimulus` | ||
object (e.g., scalar, NumPy array, pulse train). | ||
preprocess : bool or callable, optional | ||
Either True/False to indicate whether to execute the implant's default | ||
preprocessing method whenever a new stimulus is assigned, or a custom | ||
function (callable). | ||
safe_mode : bool, optional | ||
If safe mode is enabled, only charge-balanced stimuli are allowed. | ||
""" | ||
|
||
# Frozen class: User cannot add more class attributes | ||
__slots__ = ('_implants', '_earray', '_stim', 'safe_mode', 'preprocess') | ||
|
||
def __init__(self, implants, stim=None, preprocess=False,safe_mode=False): | ||
self.implants = implants | ||
self.safe_mode = safe_mode | ||
self.preprocess = preprocess | ||
self.stim = stim | ||
|
||
def _pprint_params(self): | ||
"""Return dict of class attributes to pretty-print""" | ||
return {'implants': self.implants, 'earray': self.earray, 'stim': self.stim, | ||
'safe_mode': self.safe_mode, 'preprocess': self.preprocess} | ||
|
||
@property | ||
def implants(self): | ||
"""Dict of implants | ||
""" | ||
return self._implants | ||
|
||
@implants.setter | ||
def implants(self, implants): | ||
"""Implant dict setter (called upon ``self.implants = implants``)""" | ||
# Assign the implant dict: | ||
if isinstance(implants, list): | ||
if not all(isinstance(implant, ProsthesisSystem) for implant in implants): | ||
raise TypeError(f"All elements in 'implants' must be ProsthesisSystem objects.") | ||
self._implants = {i:implant for i,implant in enumerate(implants)} | ||
elif isinstance(implants, dict): | ||
if not all(isinstance(implant, ProsthesisSystem) for implant in implants.values()): | ||
raise TypeError(f"All elements in 'implants' must be ProsthesisSystem objects.") | ||
self._implants = implants.copy() | ||
else: | ||
raise TypeError(f"'implants' must be a list or a dict object, not " | ||
f"{type(implants)}.") | ||
# Create the electrode array | ||
electrodes = {} | ||
for i, implant in self._implants.items(): | ||
for name, electrode in implant.earray.electrodes.items(): | ||
electrodes[str(i) + "-" + str(name)] = electrode | ||
self._earray = ElectrodeArray(electrodes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import numpy as np | ||
import numpy.testing as npt | ||
import pytest | ||
from pulse2percept.implants import (EnsembleImplant, PointSource, ProsthesisSystem) | ||
from pulse2percept.implants.cortex import Cortivis | ||
from pulse2percept.models.cortex.base import ScoreboardModel | ||
|
||
def test_EnsembleImplant(): | ||
# Invalid instantiations: | ||
with pytest.raises(TypeError): | ||
EnsembleImplant(implants="this can't happen") | ||
with pytest.raises(TypeError): | ||
EnsembleImplant(implants=[3,Cortivis()]) | ||
with pytest.raises(TypeError): | ||
EnsembleImplant(implants={'1': Cortivis(), '2': 'abcd'}) | ||
|
||
# Instantiate with list | ||
p1 = ProsthesisSystem(PointSource(0,0,0)) | ||
p2 = ProsthesisSystem(PointSource(1,1,1)) | ||
ensemble = EnsembleImplant(implants=[p1,p2]) | ||
npt.assert_equal(ensemble.n_electrodes, 2) | ||
npt.assert_equal(ensemble[0], p1[0]) | ||
npt.assert_equal(ensemble[1], p2[0]) | ||
npt.assert_equal(ensemble.electrode_names, ['0-0','1-0']) | ||
|
||
# Instantiate with dict | ||
ensemble = EnsembleImplant(implants={'A': p2, 'B': p1}) | ||
npt.assert_equal(ensemble.n_electrodes, 2) | ||
npt.assert_equal(ensemble[0], p2[0]) | ||
npt.assert_equal(ensemble[1], p1[0]) | ||
npt.assert_equal(ensemble.electrode_names, ['A-0','B-0']) | ||
|
||
# predict_percept smoke test | ||
ensemble.stim = [1,1] | ||
model = ScoreboardModel().build() | ||
model.predict_percept(ensemble) | ||
|
||
# we essentially just need to make sure that electrode names are | ||
# set properly, the rest of the EnsembleImplant functionality | ||
# (electrode placement, etc) is determined by the implants passed in | ||
# and thus already tested | ||
# but we'll test it again just to make sure | ||
def test_ensemble_cortivis(): | ||
cortivis0 = Cortivis(0) | ||
cortivis1 = Cortivis(x=10000) | ||
|
||
ensemble = EnsembleImplant([cortivis0, cortivis1]) | ||
|
||
# check that positions are the same | ||
npt.assert_equal(ensemble['0-1'].x, cortivis0['1'].x) | ||
npt.assert_equal(ensemble['0-1'].y, cortivis0['1'].y) | ||
npt.assert_equal(ensemble['1-1'].x, cortivis1['1'].x) | ||
npt.assert_equal(ensemble['1-1'].y, cortivis1['1'].y) |