Source code for miv_simulator.spikedata

from typing import Dict, List

import copy
from collections import defaultdict

import numpy as np
from miv_simulator.utils import (
    AbstractEnv,
    Struct,
    baks,
    get_module_logger,
    get_trial_time_ranges,
)
from neuroh5.io import scatter_read_cell_attributes, write_cell_attributes
from numpy import ndarray, uint32

## This logger will inherit its setting from its root logger
## which is created in module env
logger = get_module_logger(__name__)

# Default options for spike_density_estimate(): keyword arguments given to
# that function are merged over a copy of these defaults, and the
# "BAKS Alpha" / "BAKS Beta" entries are forwarded to baks() as its
# `a` and `b` keyword arguments.
default_baks_analysis_options = Struct(
    **{"BAKS Alpha": 4.77, "BAKS Beta": None}
)


def get_env_spike_dict(
    env: AbstractEnv, include_artificial: bool = True
) -> Dict[str, Dict[uint32, List[ndarray]]]:
    """
    Constructs a dictionary with per-gid per-trial spike times from
    the output vectors with spike times and gids contained in env.

    :param env: simulation environment. The attributes read here are
        ``stimulus_config["Equilibration Duration"]``, ``n_trials``,
        ``tstop``, ``t_vec``, ``id_vec``, ``t_rec``, ``celltypes``,
        ``artificial_cells`` and ``spike_onset_delay``.
    :param include_artificial: when False, spikes whose gid is contained in
        ``env.artificial_cells[pop_name]`` are skipped.
    :return: dict mapping population name -> gid -> list holding one
        float32 array of spike times per trial. Each trial's times are
        shifted by the summed duration of the preceding trials plus the
        equilibration duration.
    """
    equilibration_duration = float(
        env.stimulus_config["Equilibration Duration"]
    )
    n_trials = env.n_trials

    t_vec = np.array(env.t_vec.to_python(), dtype=np.float32)
    id_vec = np.array(env.id_vec.to_python(), dtype=np.uint32)

    trial_time_ranges = get_trial_time_ranges(env.t_rec.to_python(), n_trials)
    trial_time_bins = [t_trial_start for t_trial_start, _ in trial_time_ranges]
    trial_dur = np.asarray(
        [env.tstop + equilibration_duration] * n_trials, dtype=np.float32
    )
    # Offset subtracted from the spikes of each trial; it does not depend on
    # the cell, so compute it once instead of once per gid.
    trial_offsets = [
        np.sum(trial_dur[:trial_i]) + equilibration_duration
        for trial_i in range(n_trials)
    ]

    # Order the populations by their "start" value and use the starts of all
    # but the first population as bin edges, so that np.digitize maps each
    # gid to the position of its population in pop_names.
    typelst = sorted(env.celltypes.keys())
    binvect = np.asarray([env.celltypes[k]["start"] for k in typelst])
    sort_idx = np.argsort(binvect, axis=0)
    pop_names = [typelst[i] for i in sort_idx]
    bins = binvect[sort_idx][1:]
    inds = np.digitize(id_vec, bins)

    pop_spkdict = {}
    for i, pop_name in enumerate(pop_names):
        # np.where returns a tuple of index arrays; take the array itself
        # (testing the tuple's length, as before, was always true).
        sinds = np.where(inds == i)[0]

        # Group the spike times of this population by gid.
        gid_spikes = defaultdict(list)
        for gid, t in zip(id_vec[sinds], t_vec[sinds]):
            if (not include_artificial) and (
                gid in env.artificial_cells[pop_name]
            ):
                continue
            gid_spikes[gid].append(t)

        spkdict = {}
        for gid, spike_times in gid_spikes.items():
            spiketrain = np.array(spike_times, dtype=np.float32)
            if gid in env.spike_onset_delay:
                spiketrain -= env.spike_onset_delay[gid]
            # Index of the trial each spike falls into, based on the trial
            # start times.
            trial_bins = np.digitize(spiketrain, trial_time_bins) - 1
            trial_spikes = []
            for trial_i in range(n_trials):
                # Boolean-mask indexing yields a new array, so the in-place
                # shift below does not touch `spiketrain`.
                trial_spiketrain = spiketrain[trial_bins == trial_i]
                trial_spiketrain -= trial_offsets[trial_i]
                trial_spikes.append(trial_spiketrain)
            spkdict[gid] = trial_spikes

        pop_spkdict[pop_name] = spkdict

    return pop_spkdict
def read_spike_events(
    input_file,
    population_names,
    namespace_id,
    spike_train_attr_name="t",
    time_range=None,
    max_spikes=None,
    n_trials=-1,
    merge_trials=False,
    comm=None,
    io_size=0,
    include_artificial=True,
):
    """
    Reads spike trains from a NeuroH5 file, and returns a dictionary
    with spike times and cell indices.

    :param input_file: str (path to file)
    :param population_names: list of str
    :param namespace_id: str
    :param spike_train_attr_name: str
    :param time_range: pair ``[t_start, t_stop]`` or None. A None start is
        treated as 0.0 and a None stop as unbounded. The argument itself is
        never modified.
    :param max_spikes: if not None, populations with more spikes than this
        are randomly subsampled (without replacement) to ``int(max_spikes)``
        spikes
    :param n_trials: int >= 1, or -1 to infer the number of trials from the
        data
    :param merge_trials: bool; when True, spike times are shifted by the
        summed duration of the preceding trials and the per-trial arrays of
        each population are concatenated
    :param comm: passed through to ``scatter_read_cell_attributes``
    :param io_size: passed through to ``scatter_read_cell_attributes``
    :param include_artificial: when False, cells whose "artificial"
        attribute is positive are skipped
    :return: dict with keys "spkpoplst", "spktlst", "spkindlst", "tmin",
        "tmax", "pop_active_cells", "num_cell_spks" and "n_trials"
    """
    assert (n_trials >= 1) or (n_trials == -1)

    trial_index_attr = "Trial Index"
    trial_dur_attr = "Trial Duration"
    artificial_attr = "artificial"
    spike_train_attr_set = {
        spike_train_attr_name,
        trial_index_attr,
        trial_dur_attr,
        artificial_attr,
    }

    # Resolve the time window into local bounds instead of writing the
    # default back into the caller's list (which also failed for tuples).
    # An open upper bound becomes +inf so the comparison below stays valid.
    if time_range is None:
        t_lo = t_hi = None
    else:
        t_lo = 0.0 if time_range[0] is None else time_range[0]
        t_hi = float("inf") if time_range[1] is None else time_range[1]

    spkpoplst = []
    spkindlst = []
    spktlst = []
    num_cell_spks = {}
    pop_active_cells = {}

    tmin = float("inf")
    tmax = 0.0
    for pop_name in population_names:
        if time_range is None or time_range[1] is None:
            logger.info(
                f"Reading spike data for population {pop_name} namespace {namespace_id}..."
            )
        else:
            logger.info(
                f"Reading spike data for population {pop_name} namespace {namespace_id} in time range {time_range}..."
            )

        spkiter_dict = scatter_read_cell_attributes(
            input_file,
            pop_name,
            namespaces=[namespace_id],
            mask=spike_train_attr_set,
            comm=comm,
            io_size=io_size,
        )
        spkiter = spkiter_dict[namespace_id]

        this_num_cell_spks = 0
        active_set = set()

        pop_spkindlst = []
        pop_spktlst = []
        pop_spktriallst = []

        logger.info(f"Read spike cell attributes for population {pop_name}...")

        for spkind, spkattrs in spkiter:
            is_artificial_flag = spkattrs.get(artificial_attr, None)
            if (
                is_artificial_flag is not None
                and is_artificial_flag[0] > 0
                and not include_artificial
            ):
                continue

            spkts = spkattrs[spike_train_attr_name]
            trial_dur = spkattrs.get(trial_dur_attr, np.asarray([0.0]))
            trial_ind = spkattrs.get(
                trial_index_attr, np.zeros((len(spkts),), dtype=np.uint8)
            )
            if n_trials == -1:
                # NOTE(review): the trial count is inferred from the first
                # cell that is read, so it only reflects the trials in which
                # that cell has spikes -- confirm this is intended.
                n_trials = len(set(trial_ind))

            # Trial indices are 0-based (spikes are grouped with
            # range(n_trials) below), so valid indices are < n_trials.
            filtered_spk_idxs_by_trial = np.argwhere(
                trial_ind < n_trials
            ).ravel()
            filtered_spkts = spkts[filtered_spk_idxs_by_trial]
            filtered_trial_ind = trial_ind[filtered_spk_idxs_by_trial]

            if time_range is not None:
                filtered_spk_idxs_by_time = np.argwhere(
                    np.logical_and(
                        filtered_spkts >= t_lo,
                        filtered_spkts <= t_hi,
                    )
                ).ravel()
                filtered_spkts = filtered_spkts[filtered_spk_idxs_by_time]
                filtered_trial_ind = filtered_trial_ind[
                    filtered_spk_idxs_by_time
                ]

            pop_spkindlst.append(
                np.repeat([spkind], len(filtered_spkts)).astype(np.uint32)
            )
            pop_spktriallst.append(filtered_trial_ind)
            this_num_cell_spks += len(filtered_spkts)
            active_set.add(spkind)

            for spkt, trial_i in zip(filtered_spkts, filtered_trial_ind):
                if merge_trials:
                    # Shift by the total duration of the preceding trials.
                    spkt += np.sum(trial_dur[:trial_i])
                pop_spktlst.append(spkt)
                tmin = min(tmin, spkt)
                tmax = max(tmax, spkt)

        pop_active_cells[pop_name] = active_set
        num_cell_spks[pop_name] = this_num_cell_spks

        if not active_set:
            continue

        pop_spkts = np.asarray(pop_spktlst, dtype=np.float32)
        del pop_spktlst
        pop_spkinds = np.concatenate(pop_spkindlst, dtype=np.uint32)
        del pop_spkindlst
        pop_spktrials = np.concatenate(pop_spktriallst, dtype=np.uint32)
        del pop_spktriallst

        # Limit to max_spikes
        if (max_spikes is not None) and (len(pop_spkts) > max_spikes):
            logger.warning(
                f" Reading only randomly sampled {max_spikes} out of {len(pop_spkts)} spikes for population {pop_name}"
            )
            # Draw distinct positions over the full range; the same
            # positions must be applied to all three parallel arrays.
            sample_inds = np.random.choice(
                len(pop_spkts), size=int(max_spikes), replace=False
            )
            pop_spkts = pop_spkts[sample_inds]
            pop_spkinds = pop_spkinds[sample_inds]
            pop_spktrials = pop_spktrials[sample_inds]
            # tmax already covers every spike read above, so it needs no
            # update for the subsample.

        spkpoplst.append(pop_name)

        # Split the population's spikes by trial, each sorted by time.
        pop_trial_spkindlst = []
        pop_trial_spktlst = []
        for trial_i in range(n_trials):
            trial_idxs = np.where(pop_spktrials == trial_i)[0]
            sorted_trial_idxs = np.argsort(pop_spkts[trial_idxs])
            pop_trial_spktlst.append(
                np.take(pop_spkts[trial_idxs], sorted_trial_idxs)
            )
            pop_trial_spkindlst.append(
                np.take(pop_spkinds[trial_idxs], sorted_trial_idxs)
            )

        del pop_spkts
        del pop_spkinds
        del pop_spktrials

        if merge_trials:
            spkindlst.append(np.concatenate(pop_trial_spkindlst))
            spktlst.append(np.concatenate(pop_trial_spktlst))
        else:
            spkindlst.append(pop_trial_spkindlst)
            spktlst.append(pop_trial_spktlst)

        logger.info(
            f" Read {this_num_cell_spks} spikes and {n_trials} trials for population {pop_name}"
        )

    # No spikes at all: report an empty [0, 0] time span.
    if tmin == float("inf"):
        tmin = 0.0
    if tmax == float("inf"):
        tmax = 0.0

    return {
        "spkpoplst": spkpoplst,
        "spktlst": spktlst,
        "spkindlst": spkindlst,
        "tmin": tmin,
        "tmax": tmax,
        "pop_active_cells": pop_active_cells,
        "num_cell_spks": num_cell_spks,
        "n_trials": n_trials,
    }
def make_spike_dict(spkinds, spkts):
    """
    Given arrays with cell indices and spike times, returns a
    dictionary with per-cell spike times.

    :param spkinds: array-like of cell indices
    :param spkts: array-like of spike times, parallel to ``spkinds``
    :return: defaultdict(list) mapping ``int`` cell index to the list of its
        spike times as ``float``, in input order. Empty input yields an
        empty dictionary.
    """
    spk_dict = defaultdict(list)
    # Flatten both inputs in logical order so element i of one is always
    # paired with element i of the other. The previous two independent
    # np.nditer objects walked each operand in its own memory order and
    # raised ValueError for zero-sized arrays.
    flat_inds = np.asarray(spkinds).ravel()
    flat_ts = np.asarray(spkts).ravel()
    for spkind, spkt in zip(flat_inds, flat_ts):
        spk_dict[int(spkind)].append(float(spkt))
    return spk_dict
def spike_density_estimate(
    population,
    spkdict,
    time_bins,
    arena_id=None,
    trajectory_id=None,
    output_file_path=None,
    progress=False,
    inferred_rate_attr_name="Inferred Rate Map",
    **kwargs,
):
    """
    Calculates spike density function for the given spike trains.

    Spike times outside ``[time_bins[0], time_bins[-1]]`` are discarded.
    Cells left with at most one spike get an all-zero rate; for the others
    the rate is the first element returned by ``baks``, flattened to 1-D.
    Spike times and ``time_bins`` are divided by 1000 before being passed to
    ``baks`` (presumably a ms -> s conversion; confirm against ``baks``).

    :param population: population name, used when writing to
        ``output_file_path``
    :param spkdict: dict mapping cell index to a sequence of spike times
    :param time_bins: array of time points at which the rate is estimated
    :param arena_id: str; required when ``output_file_path`` is given
    :param trajectory_id: str; required when ``output_file_path`` is given
    :param output_file_path: if not None, the rates are also written with
        ``write_cell_attributes`` to the namespace
        ``"Spike Density Function {arena_id} {trajectory_id}"``
    :param progress: if True, wrap the per-cell loop in a tqdm progress bar
    :param inferred_rate_attr_name: str; attribute name used for the rate
        when writing to file
    :param kwargs: dict; overrides for ``default_baks_analysis_options``
    :return: dict mapping cell index to ``{"rate": ..., "time": time_bins}``
    :raises RuntimeError: if ``output_file_path`` is given without both
        ``arena_id`` and ``trajectory_id``
    """
    # Validate the output arguments up front so that a bad call fails before
    # any rate is computed.
    if output_file_path is not None and (
        arena_id is None or trajectory_id is None
    ):
        raise RuntimeError(
            "spike_density_estimate: arena_id and trajectory_id required to write Spike Density "
            "Function namespace"
        )

    analysis_options = copy.copy(default_baks_analysis_options)
    analysis_options.update(kwargs)

    t_start = time_bins[0]
    t_stop = time_bins[-1]

    def make_spktrain(lst):
        # Keep only the spikes inside the analysed time window.
        spkts = np.asarray(lst, dtype=np.float32)
        return spkts[(spkts >= t_start) & (spkts <= t_stop)]

    spktrains = {ind: make_spktrain(lst) for ind, lst in spkdict.items()}

    baks_args = {
        "a": analysis_options["BAKS Alpha"],
        "b": analysis_options["BAKS Beta"],
    }

    if progress:
        # Imported lazily so tqdm is only needed when a progress bar is
        # requested.
        from tqdm import tqdm

        seq = tqdm(spktrains.items())
    else:
        seq = spktrains.items()

    spk_rate_dict = {
        ind: baks(spkts / 1000.0, time_bins / 1000.0, **baks_args)[0].reshape(
            (-1,)
        )
        if len(spkts) > 1
        else np.zeros(time_bins.shape)
        for ind, spkts in seq
    }

    if output_file_path is not None:
        namespace = f"Spike Density Function {arena_id} {trajectory_id}"
        attr_dict = {
            ind: {inferred_rate_attr_name: np.asarray(rate, dtype="float32")}
            for ind, rate in spk_rate_dict.items()
        }
        write_cell_attributes(
            output_file_path, population, attr_dict, namespace=namespace
        )

    return {
        ind: {"rate": rate, "time": time_bins}
        for ind, rate in spk_rate_dict.items()
    }
def spike_bin_counts(spkdict, time_bins):
    """
    Counts the spikes of each cell that fall into the given time bins.

    :param spkdict: dict mapping cell index to a sequence of spike times
    :param time_bins: bin specification handed to ``np.histogram`` as
        ``bins``
    :return: dict mapping cell index to the array of per-bin spike counts;
        cells with an empty spike sequence are left out
    """
    return {
        cell_id: np.histogram(
            np.asarray(spike_times, dtype=np.float32), bins=time_bins
        )[0]
        for cell_id, spike_times in spkdict.items()
        if len(spike_times) > 0
    }