Skip to content
Open
79 changes: 53 additions & 26 deletions src/spikeinterface/extractors/phykilosortextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
ComputeTemplates,
create_sorting_analyzer,
SortingAnalyzer,
aggregate_channels,
)
from spikeinterface.core.core_tools import define_function_from_class

from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
from probeinterface import read_prb, Probe
from probeinterface import read_prb, Probe, ProbeGroup


class BasePhyKilosortSortingExtractor(BaseSorting):
Expand Down Expand Up @@ -314,7 +315,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")


def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer:
def read_kilosort_as_analyzer(
folder_path, recording=None, unwhiten=True, gain_to_uV=None, offset_to_uV=None
) -> SortingAnalyzer:
"""
Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
above are supported. The function may work on older versions of Kilosort output,
Expand All @@ -324,6 +327,8 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
----------
folder_path : str or Path
Path to the output Phy folder (containing the params.py).
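recording : BaseRecording | None, default: None
A spikeinterface Recording object to attach to the analyzer. If given, its channels are
re-wired to match the Kilosort probe layout. If None, a placeholder recording is used to
build the analyzer and is removed before the analyzer is returned.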
unwhiten : bool, default: True
Unwhiten the templates computed by kilosort.
gain_to_uV : float | None, default: None
Expand Down Expand Up @@ -353,31 +358,56 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
sorting = read_phy(phy_path)
sampling_frequency = sorting.sampling_frequency

# kilosort occasionally contains a few spikes just beyond the recording end point, which can lead
# to errors later. To avoid this, we pad the recording with an extra second of blank time.
duration = sorting.segments[0]._all_spikes[-1] / sampling_frequency + 1

if (phy_path / "probe.prb").is_file():
probegroup = read_prb(phy_path / "probe.prb")
if len(probegroup.probes) > 0:
warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.")
probe = probegroup.probes[0]
elif (phy_path / "channel_positions.npy").is_file():
probe = Probe(si_units="um")
channel_positions = np.load(phy_path / "channel_positions.npy")
probe.set_contacts(channel_positions)
probe.set_device_channel_indices(range(probe.get_contact_count()))
channel_map = np.load(phy_path / "channel_map.npy")
probe.set_device_channel_indices(channel_map)

probegroup = ProbeGroup()
probegroup.add_probe(probe)
else:
AssertionError(f"Cannot read probe layout from folder {phy_path}.")

# to make the initial analyzer, we'll use a fake recording and set it to None later
recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)
if recording is not None:
# Re-wire recording to match the output from the kilosort probe
user_gave_recording = True
all_contact_positions = np.vstack([probe.contact_positions for probe in probegroup.probes])

new_device_channel_indices = []
for recording_channel_location in recording.get_channel_locations():
for channel_index, probe_contact_position in enumerate(all_contact_positions):
if np.all(recording_channel_location == probe_contact_position):
new_device_channel_indices.append(channel_index)
break

if len(new_device_channel_indices) != len(all_contact_positions):
Comment thread
chrishalcrow marked this conversation as resolved.
raise ValueError("The channel locations in your `recording` and the probe channel locations do not match.")

recording.get_probegroup().set_global_device_channel_indices(new_device_channel_indices)

else:
user_gave_recording = False

# Kilosort output occasionally contains a few spikes just beyond the end of the recording, which
# can lead to errors later. To avoid this, we pad the recording with an extra second of blank time.
duration = sorting.segments[0]._all_spikes[-1] / sampling_frequency + 1

# to make the initial analyzer, we'll use a fake recording and set it to None later
recordings = []
for probe in probegroup.probes:
one_recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)
recordings.append(one_recording)
recording = aggregate_channels(recordings)

sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

Expand All @@ -397,7 +427,9 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
)
_make_locations(sorting_analyzer, phy_path)

sorting_analyzer._recording = None
if not user_gave_recording:
sorting_analyzer._recording = None

return sorting_analyzer


Expand All @@ -413,14 +445,9 @@ def _make_locations(sorting_analyzer, kilosort_output_path):
else:
return

# Check that the spike locations vector is the same size as the spike vector
# When a recording is given, the spike locations need to be trimmed to match the spikes in the sorting
num_spikes = len(sorting_analyzer.sorting.to_spike_vector())
num_spike_locs = len(locs_np)
if num_spikes != num_spike_locs:
warnings.warn(
"The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations."
)
return
locs_np = locs_np[:num_spikes]

num_dims = len(locs_np[0])
column_names = ["x", "y", "z"][:num_dims]
Expand Down
Loading