From bbd81daeaf0dac527bd0afa16f70df7c9815c0b5 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Jun 2026 14:15:15 +0200 Subject: [PATCH 1/6] Add the ProbeGroup._global_contact_order concept. --- src/probeinterface/probegroup.py | 145 ++++++++++++++++++++++--------- 1 file changed, 103 insertions(+), 42 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index d42906a4..e50b3338 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -7,12 +7,21 @@ class ProbeGroup: """ Class to handle a group of Probe objects and the global wiring to a device. - Optionally, it can handle the location of different probes. + Internally, this is represented as a list of Probe object. + + The ProbeGroup is the object saved in the json based probeinterface format, even if there only one probe. + + Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order + is the "natural" one (stacked order of each probe). But optionally, this order can be more complex, for instance + some contact of each probe are interleaved, in this case a optional reordering can be applied. + + """ def __init__(self): self.probes = [] + self._global_contact_order = None def add_probe(self, probe: Probe) -> None: """ @@ -114,6 +123,9 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: pg_arr.append(arr_ext) pg_arr = np.concatenate(pg_arr, axis=0) + + if self._global_contact_order is not None: + pg_arr = pg_arr[self._global_contact_order] return pg_arr @staticmethod @@ -121,6 +133,10 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": """Create ProbeGroup from a complex numpy array see ProbeGroup.to_numpy() + Note that if the contact_vector has several probe and some contact are interleaved, then the ProbeGroup will + have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order + will be not None. + Parameters ---------- arr : np.array @@ -131,7 +147,13 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": probegroup : ProbeGroup The instantiated ProbeGroup object """ - from .probe import Probe + + # Check if contacts are interleaved + num_probes = np.unique(arr["probe_index"]).size + is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0) + print('is_interleaved', is_interleaved) + if is_interleaved: + global_contact_order = [] probes_indices = np.unique(arr["probe_index"]) probegroup = ProbeGroup() @@ -139,6 +161,14 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": mask = arr["probe_index"] == probe_index probe = Probe.from_numpy(arr[mask]) probegroup.add_probe(probe) + + if is_interleaved: + global_contact_order.append(np.flatnonzero(mask)) + + if is_interleaved: + # the argsort is for the 'reverse' order! + probegroup._global_contact_order = np.argsort(np.concatenate(global_contact_order)) + return probegroup def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": @@ -181,6 +211,11 @@ def to_dict(self, array_as_list: bool = False) -> dict: for probe_ind, probe in enumerate(self.probes): probe_dict = probe.to_dict(array_as_list=array_as_list) d["probes"].append(probe_dict) + if self._global_contact_order is not None: + global_contact_order = self._global_contact_order + if array_as_list: + global_contact_order = global_contact_order.to_list() + d["global_contact_order"] = global_contact_order return d @staticmethod @@ -201,6 +236,11 @@ def from_dict(d: dict) -> "ProbeGroup": for probe_dict in d["probes"]: probe = Probe.from_dict(probe_dict) probegroup.add_probe(probe) + + global_contact_order = d.get("global_contact_order", None) + if global_contact_order is not None: + probegroup._global_contact_order = np.asarray(global_contact_order) + return probegroup def get_global_device_channel_indices(self) -> np.ndarray: @@ -226,31 +266,40 @@ def get_global_device_channel_indices(self) -> np.ndarray: channels["device_channel_indices"] = arr["device_channel_indices"] return channels - def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None: + def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | list) -> None: """ - Set global indices for all probes + Set global indices for all probes. + + Important note : if the order of contacts is not "natural" then the device_channel_indices + is applied is the real/reordered contacts vector. In short, the device_channel_indices is ziped to + ProbeGroup.to_numpy() (always ordered). Parameters ---------- channels: np.ndarray | list The device channal indices to be set """ - channels = np.asarray(channels) - if channels.size != self.get_contact_count(): + device_channel_indices = np.asarray(device_channel_indices) + if device_channel_indices.size != self.get_contact_count(): raise ValueError( - f"Wrong channels size {channels.size} for the number of channels {self.get_contact_count()}" + f"Wrong channels size {device_channel_indices.size} for the number of channels {self.get_contact_count()}" ) # first reset previous indices for i, probe in enumerate(self.probes): n = probe.get_contact_count() probe.set_device_channel_indices([-1] * n) + + if self._global_contact_order is not None: + # this is tricky conceptually but needed but needed for consistency + rev_order = np.argsort(self._global_contact_order) + device_channel_indices = device_channel_indices[rev_order] # then set new indices ind = 0 for i, probe in enumerate(self.probes): n = probe.get_contact_count() - probe.set_device_channel_indices(channels[ind : ind + n]) + probe.set_device_channel_indices(device_channel_indices[ind : ind + n]) ind += n def get_global_contact_ids(self) -> np.ndarray: @@ -275,6 +324,8 @@ def get_global_contact_positions(self) -> np.ndarray: An array of the contact positions across all probes """ contact_positions = np.vstack([probe.contact_positions for probe in self.probes]) + if self._global_contact_order is not None: + contact_positions = contact_positions[self._global_contact_order] return contact_positions def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": @@ -295,55 +346,65 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": The sliced probe group """ + # TODO SAM order!!! n = self.get_contact_count() selection = np.asarray(selection) + if selection.dtype.kind not in ("b", "i"): + raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + + if selection.dtype == "bool": assert selection.shape == ( n, ), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" - (selection_indices,) = np.nonzero(selection) - elif selection.dtype.kind == "i": - assert np.unique(selection).size == selection.size - if len(selection) > 0: - assert ( - 0 <= np.min(selection) < n - ), f"An index within your selection is out of bounds {np.min(selection)}" - assert ( - 0 <= np.max(selection) < n - ), f"An index within your selection is out of bounds {np.max(selection)}" - selection_indices = selection - else: - selection_indices = [] - else: - raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + selection_indices = np.flatnonzero(selection) - if len(selection_indices) == 0: - return ProbeGroup() + if len(selection) == 0: + raise ValueError("ProbeGroup.get_slice() with empty selection is not handled") + # return ProbeGroup() - # Map selection to indices of individual probes - ind = 0 - sliced_probes = [] - for probe in self.probes: - n = probe.get_contact_count() - probe_limits = (ind, ind + n) - ind += n + assert np.unique(selection).size == selection.size + assert ( + 0 <= np.min(selection) < n + ), f"An index within your selection is out of bounds {np.min(selection)}" + assert ( + 0 <= np.max(selection) < n + ), f"An index within your selection is out of bounds {np.max(selection)}" + selection_indices = selection - probe_selection_indices = selection_indices[ - (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) - ] - if len(probe_selection_indices) == 0: - continue - sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) - sliced_probes.append(sliced_probe) - sliced_probe_group = ProbeGroup() - for probe in sliced_probes: - sliced_probe_group.add_probe(probe) + contact_arr = self.to_numpy(complete=True) + contact_arr = contact_arr[selection] + sliced_probe_group = ProbeGroup.from_numpy(contact_arr) + + # TODO annoatation!! return sliced_probe_group + # # Map selection to indices of individual probes + # ind = 0 + # sliced_probes = [] + # for probe in self.probes: + # n = probe.get_contact_count() + # probe_limits = (ind, ind + n) + # ind += n + + # probe_selection_indices = selection_indices[ + # (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) + # ] + # if len(probe_selection_indices) == 0: + # continue + # sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) + # sliced_probes.append(sliced_probe) + + # sliced_probe_group = ProbeGroup() + # for probe in sliced_probes: + # sliced_probe_group.add_probe(probe) + + # return sliced_probe_group + def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() From 16695d74a5242afedb369f3452dbe6a31b3467b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:27:58 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/probegroup.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index e50b3338..8dc63b7c 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -134,7 +134,7 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": see ProbeGroup.to_numpy() Note that if the contact_vector has several probe and some contact are interleaved, then the ProbeGroup will - have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order + have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order will be not None. Parameters @@ -151,7 +151,7 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": # Check if contacts are interleaved num_probes = np.unique(arr["probe_index"]).size is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0) - print('is_interleaved', is_interleaved) + print("is_interleaved", is_interleaved) if is_interleaved: global_contact_order = [] @@ -164,7 +164,7 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": if is_interleaved: global_contact_order.append(np.flatnonzero(mask)) - + if is_interleaved: # the argsort is for the 'reverse' order! probegroup._global_contact_order = np.argsort(np.concatenate(global_contact_order)) @@ -236,7 +236,7 @@ def from_dict(d: dict) -> "ProbeGroup": for probe_dict in d["probes"]: probe = Probe.from_dict(probe_dict) probegroup.add_probe(probe) - + global_contact_order = d.get("global_contact_order", None) if global_contact_order is not None: probegroup._global_contact_order = np.asarray(global_contact_order) @@ -289,7 +289,7 @@ def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | for i, probe in enumerate(self.probes): n = probe.get_contact_count() probe.set_device_channel_indices([-1] * n) - + if self._global_contact_order is not None: # this is tricky conceptually but needed but needed for consistency rev_order = np.argsort(self._global_contact_order) @@ -354,7 +354,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": if selection.dtype.kind not in ("b", "i"): raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") - if selection.dtype == "bool": assert selection.shape == ( n, @@ -363,18 +362,13 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": if len(selection) == 0: raise ValueError("ProbeGroup.get_slice() with empty selection is not handled") - # return ProbeGroup() + # return ProbeGroup() assert np.unique(selection).size == selection.size - assert ( - 0 <= np.min(selection) < n - ), f"An index within your selection is out of bounds {np.min(selection)}" - assert ( - 0 <= np.max(selection) < n - ), f"An index within your selection is out of bounds {np.max(selection)}" + assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}" + assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}" selection_indices = selection - contact_arr = self.to_numpy(complete=True) contact_arr = contact_arr[selection] sliced_probe_group = ProbeGroup.from_numpy(contact_arr) From 10fac954fca1ce195de0aa4d62491a2f189295f0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Jun 2026 14:55:43 +0200 Subject: [PATCH 3/6] Add some tests for ordering --- src/probeinterface/probegroup.py | 27 +------------- tests/test_probegroup.py | 63 +++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index e50b3338..5bc732e6 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -151,7 +151,6 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": # Check if contacts are interleaved num_probes = np.unique(arr["probe_index"]).size is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0) - print('is_interleaved', is_interleaved) if is_interleaved: global_contact_order = [] @@ -291,7 +290,7 @@ def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | probe.set_device_channel_indices([-1] * n) if self._global_contact_order is not None: - # this is tricky conceptually but needed but needed for consistency + # this is tricky conceptually but needed for consistency rev_order = np.argsort(self._global_contact_order) device_channel_indices = device_channel_indices[rev_order] @@ -346,7 +345,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": The sliced probe group """ - # TODO SAM order!!! n = self.get_contact_count() @@ -379,31 +377,10 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": contact_arr = contact_arr[selection] sliced_probe_group = ProbeGroup.from_numpy(contact_arr) - # TODO annoatation!! + # TODO annoatation probe per probe!! return sliced_probe_group - # # Map selection to indices of individual probes - # ind = 0 - # sliced_probes = [] - # for probe in self.probes: - # n = probe.get_contact_count() - # probe_limits = (ind, ind + n) - # ind += n - - # probe_selection_indices = selection_indices[ - # (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) - # ] - # if len(probe_selection_indices) == 0: - # continue - # sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) - # sliced_probes.append(sliced_probe) - - # sliced_probe_group = ProbeGroup() - # for probe in sliced_probes: - # sliced_probe_group.add_probe(probe) - - # return sliced_probe_group def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index ddd332d4..32ff7d6d 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -6,8 +6,9 @@ import numpy as np -@pytest.fixture -def probegroup(): + + +def _make_probegroup(): """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" probegroup = ProbeGroup() nchan = 0 @@ -21,6 +22,11 @@ def probegroup(): return probegroup +@pytest.fixture +def probegroup(): + return _make_probegroup() + + def test_probegroup(probegroup): indices = probegroup.get_global_device_channel_indices() @@ -200,7 +206,7 @@ def test_copy_is_independent(probegroup): np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions) -# ── get_slice() tests ─────────────────────────────────────────────────────── +# ── get_slice() simple : natural order def test_get_slice_by_bool(probegroup): @@ -232,10 +238,10 @@ def test_get_slice_preserves_positions(probegroup): np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected) -def test_get_slice_empty_selection(probegroup): - sliced = probegroup.get_slice(np.array([], dtype=int)) - assert sliced.get_contact_count() == 0 - assert len(sliced.probes) == 0 +# def test_get_slice_empty_selection(probegroup): +# sliced = probegroup.get_slice(np.array([], dtype=int)) +# assert sliced.get_contact_count() == 0 +# assert len(sliced.probes) == 0 def test_get_slice_wrong_bool_size(probegroup): @@ -259,7 +265,46 @@ def test_get_slice_all_contacts(probegroup): probegroup.get_global_contact_positions(), ) +# ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice + +def test_reordred_probegroup(probegroup): + order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) + + contact_vector = probegroup.to_numpy(complete=True) + contact_vector = contact_vector[order] + + probegroup2 = ProbeGroup.from_numpy(contact_vector) + assert probegroup2._global_contact_order is not None + contact_vector2 = probegroup2.to_numpy(complete=True) + assert np.array_equal(contact_vector, contact_vector2) + + probegroup3 = ProbeGroup.from_dict(probegroup2.to_dict()) + assert probegroup3._global_contact_order is not None + contact_vector3 = probegroup3.to_numpy(complete=True) + assert np.array_equal(contact_vector2, contact_vector3) + + probegroup4 = probegroup.get_slice(order) + assert probegroup4._global_contact_order is not None + contact_vector4 = probegroup4.to_numpy(complete=True) + assert np.array_equal(contact_vector3, contact_vector4) + + probegroup5 = ProbeGroup.from_dict(probegroup4.to_dict()) + assert probegroup5._global_contact_order is not None + contact_vector5 = probegroup3.to_numpy(complete=True) + assert np.array_equal(contact_vector4, contact_vector5) + + # let go back to original order + rev_order = np.argsort(order) + probegroup6 = probegroup5.get_slice(rev_order) + assert probegroup6._global_contact_order is None + + if __name__ == "__main__": - test_probegroup() - # ~ test_probegroup_3d() + probegroup = _make_probegroup() + + # test_probegroup(probegroup) + # test_probegroup_3d() + test_reordred_probegroup(probegroup) + + From dbd0455d85ebd7c97b556557202a4ad76f3b6b99 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:56:54 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/probegroup.py | 1 - tests/test_probegroup.py | 15 ++++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index c61d3ee5..9658b978 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -375,7 +375,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": return sliced_probe_group - def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 32ff7d6d..089c642a 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -6,8 +6,6 @@ import numpy as np - - def _make_probegroup(): """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" probegroup = ProbeGroup() @@ -265,14 +263,16 @@ def test_get_slice_all_contacts(probegroup): probegroup.get_global_contact_positions(), ) + # ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice + def test_reordred_probegroup(probegroup): order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) - + contact_vector = probegroup.to_numpy(complete=True) contact_vector = contact_vector[order] - + probegroup2 = ProbeGroup.from_numpy(contact_vector) assert probegroup2._global_contact_order is not None contact_vector2 = probegroup2.to_numpy(complete=True) @@ -296,15 +296,12 @@ def test_reordred_probegroup(probegroup): # let go back to original order rev_order = np.argsort(order) probegroup6 = probegroup5.get_slice(rev_order) - assert probegroup6._global_contact_order is None - + assert probegroup6._global_contact_order is None if __name__ == "__main__": probegroup = _make_probegroup() - + # test_probegroup(probegroup) # test_probegroup_3d() test_reordred_probegroup(probegroup) - - From bfd44b8d939305959751c19599b468b26a4de205 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 9 Jun 2026 15:03:34 +0200 Subject: [PATCH 5/6] oups --- src/probeinterface/probegroup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index c61d3ee5..941f2df9 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -356,7 +356,7 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": assert selection.shape == ( n, ), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" - selection_indices = np.flatnonzero(selection) + selection = np.flatnonzero(selection) if len(selection) == 0: raise ValueError("ProbeGroup.get_slice() with empty selection is not handled") @@ -365,7 +365,6 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": assert np.unique(selection).size == selection.size assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}" assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}" - selection_indices = selection contact_arr = self.to_numpy(complete=True) contact_arr = contact_arr[selection] From cc7320be0901bf056939187b47919a2795ec11ec Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 9 Jun 2026 16:10:40 +0200 Subject: [PATCH 6/6] Update src/probeinterface/probegroup.py Co-authored-by: Alessio Buccino --- src/probeinterface/probegroup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 119dcf5f..26f23bff 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -270,7 +270,7 @@ def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | Set global indices for all probes. Important note : if the order of contacts is not "natural" then the device_channel_indices - is applied is the real/reordered contacts vector. In short, the device_channel_indices is ziped to + is applied is the real/reordered contacts vector. In short, the device_channel_indices is zipped to ProbeGroup.to_numpy() (always ordered). Parameters