Skip to content

Commit

Permalink
[FIX] Error message for number of electrodes in implant.stim (#359)
Browse files Browse the repository at this point in the history
* fix number of electrodes

* fix error message

* add test
  • Loading branch information
mbeyeler committed Apr 4, 2021
1 parent ccc8bfb commit a6f83b6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pulse2percept/implants/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def stim(self, data):
# Use electrode names as stimulus coordinates:
stim = Stimulus(data, electrodes=list(self.earray.keys()))

if len(stim.electrodes) > self.n_electrodes:
raise ValueError("Number of electrodes in the stimulus (%d) "
"does not match the number of electrodes in "
"the implant (%d)." % (len(stim.electrodes),
self.n_electrodes))
# Make sure all electrode names are valid:
for electrode in stim.electrodes:
# Invalid index will return None:
Expand Down
9 changes: 8 additions & 1 deletion pulse2percept/implants/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy.testing as npt
from matplotlib.patches import Circle

from pulse2percept.implants import (PointSource, ElectrodeArray,
from pulse2percept.implants import (PointSource, ElectrodeArray, ElectrodeGrid,
ProsthesisSystem)
from pulse2percept.stimuli import Stimulus

Expand Down Expand Up @@ -53,3 +53,10 @@ def test_ProsthesisSystem():
# Slots:
npt.assert_equal(hasattr(implant, '__slots__'), True)
npt.assert_equal(hasattr(implant, '__dict__'), False)


def test_ProsthesisSystem_stim():
implant = ProsthesisSystem(ElectrodeGrid((13, 13), 20))
stim = Stimulus(np.ones((13 * 13 + 1, 5)))
with pytest.raises(ValueError):
implant.stim = stim

0 comments on commit a6f83b6

Please sign in to comment.