from typing import Any, List, Optional, Union, Dict, Tuple
import gc
import pathlib
import os
import subprocess
from collections import defaultdict
from miv_simulator import config
import h5py
import numpy as np
from miv_simulator.utils import (
AbstractEnv,
Iterable,
Struct,
compose_iter,
get_module_logger,
get_trial_time_ranges,
)
from mpi4py import MPI
from neuroh5.io import (
append_cell_attributes,
append_cell_trees,
read_cell_attribute_info,
scatter_read_cell_attribute_selection,
scatter_read_graph_selection,
scatter_read_tree_selection,
write_cell_attributes,
write_graph,
)
def set_union(a, b, datatype):
    """Return the union of sets ``a`` and ``b``.

    The three-argument signature matches the callback expected by
    ``MPI.Op.Create``; ``datatype`` is accepted for that protocol and ignored.
    """
    merged = a.union(b)
    return merged
# MPI user-defined reduction operator wrapping ``set_union``; commute=True
# declares the operation order-independent, which holds for set union.
mpi_op_set_union = MPI.Op.Create(set_union, commute=True)
# This logger will inherit its settings from the root logger, created in env
logger = get_module_logger(__name__)
# Names and absolute paths of the NeuroH5 type definitions stored under /H5Types.
grp_h5types = "H5Types"
grp_projections = "Projections"
grp_populations = "Populations"
# Named HDF5 enum type mapping population names to population indices.
path_population_labels = f"/{grp_h5types}/Population labels"
# Named HDF5 compound type (Start, Count, Population) for population ranges.
path_population_range = f"/{grp_h5types}/Population range"
grp_population_projections = "Population projections"
grp_valid_population_projections = "Valid population projections"
# Named HDF5 compound type (Source, Destination) for population projections.
path_population_projections = f"/{grp_h5types}/Population projections"
# Default I/O configuration
# NOTE(review): not referenced in this module as shown; presumably consumed by
# callers of the NeuroH5 read/write helpers — confirm the meaning of each field.
default_io_options = Struct(
    io_size=-1,
    chunk_size=1000,
    value_chunk_size=1000,
    cache_size=50,
    write_size=10000,
)
def list_concat(
    a: List[Union[str, Any]],
    b: List[Union[str, Any]],
    datatype: Optional[Any] = None,
) -> List[Union[str, Any]]:
    """Return a new list with the elements of ``a`` followed by those of ``b``.

    Used as the callback of an MPI user-defined reduction operator, which passes
    a third ``datatype`` argument; it is ignored here. The previous annotation
    ``datatype: None`` declared that only ``None`` was acceptable. The new
    default also lets the function be called with just the two lists.
    """
    return a + b
# MPI user-defined reduction operator wrapping ``list_concat``; recsout uses it
# with comm.allreduce to gather recorder types from all ranks. It is registered
# as commutative although list concatenation is not, so the order of the
# gathered list is not guaranteed — recsout only consumes it via set()/sorted().
mpi_op_concat = MPI.Op.Create(list_concat, commute=True)
def h5_get_group(h, groupname):
    """Return child ``groupname`` of ``h``, creating it as a group when absent.

    An existing child is returned as-is, whatever its kind.
    """
    if groupname not in h:
        return h.create_group(groupname)
    return h[groupname]
def h5_get_dataset(g, dsetname, **kwargs):
    """Return dataset ``dsetname`` of ``g``, creating an empty one if missing.

    A new dataset gets the initial shape ``(0,)``, and ``kwargs`` (e.g.
    ``dtype``, ``maxshape``) are forwarded to ``create_dataset``. When the
    dataset already exists, ``kwargs`` are ignored.
    """
    if dsetname not in g:
        return g.create_dataset(dsetname, (0,), **kwargs)
    return g[dsetname]
def h5_concat_dataset(dset, data):
    """Append ``data`` to the end of the 1-D resizable dataset ``dset``.

    Returns ``dset`` itself after growing it by ``len(data)`` elements.
    """
    old_len = dset.shape[0]
    dset.resize((old_len + len(data),))
    dset[old_len:] = data
    return dset
def show_celltypes(input_path: str, output=print):
    """
    Given the H5 file, output the table of cell types, starting id, and counts.

    The CLI command is provided as :option:`'show-h5types' <show-h5types>`.

    Parameters
    ----------
    input_path : str
        Input h5 file path
    output : Callable
        output stream (default=print); called once per table line

    Raises
    ------
    KeyError
        If ``/H5Types/Population labels`` or ``/H5Types/Populations`` is missing.
    ValueError
        If ``Population labels`` is not an enumerated type.
    """
    with h5py.File(input_path, "r") as h5:
        label_dtype = h5[path_population_labels].dtype
        enum_dtype_dict = h5py.h5t.check_enum_dtype(label_dtype)
        if enum_dtype_dict is None:
            raise ValueError(
                f"show_celltypes: {path_population_labels} in {input_path} "
                "is not an enumerated type"
            )
        # Invert the enum's name -> index mapping to label each range row.
        population_idx_dict = {idx: name for name, idx in enum_dtype_dict.items()}
        # Index directly instead of h5_get_group/h5_get_dataset: the file is
        # read-only, so a missing object must not trigger a creation attempt.
        populations = h5[grp_h5types][grp_populations][:]
    output(f"{'Name':<10} {'Start':<8} {'Count':<8}")
    output(f"{'====':<10} {'=====':<8} {'=====':<8}")
    for start, count, idx in populations:
        name = population_idx_dict[idx]
        output(f"{name:<10} {start:<8} {count:<8}")
def import_celltypes(celltype_path, output_path):
    """
    Read a tab-separated cell type table and create a new H5 file holding the
    NeuroH5 population type definitions under ``/H5Types``.

    The table must have the columns ``celltype``, ``typeIndex``, ``rangeStart``
    and ``rangeEnd`` (inclusive). ``rangeStart``/``rangeEnd`` only determine
    each population's count. The ``Start`` offsets written to the file are
    assigned contiguously from 0 in ascending ``typeIndex`` order.

    Parameters
    ----------
    celltype_path : str
        Tab-separated cell type table.
    output_path : str
        H5 file to create (mode ``"x"``: fails if it already exists).

    Returns
    -------
    list of tuple
        ``(name, type_index, count)`` for each cell type, sorted by ``type_index``.

    Raises
    ------
    ValueError
        If the table contains no cell types.
    """
    import csv

    population_dict = {}
    with open(celltype_path, newline="") as infile:
        reader = csv.DictReader(infile, delimiter="\t")
        for row in reader:
            celltype = row["celltype"]
            type_index = int(row["typeIndex"])
            range_start = int(row["rangeStart"])
            range_end = int(row["rangeEnd"])
            # rangeEnd is inclusive
            count = range_end - range_start + 1
            # A repeated celltype row replaces the earlier one.
            population_dict[celltype] = (type_index, count)

    populations = sorted(
        (
            (pop_name, pop_idx, pop_count)
            for pop_name, (pop_idx, pop_count) in population_dict.items()
        ),
        key=lambda x: x[1],
    )
    if not populations:
        raise ValueError(
            f"import_celltypes: no cell types found in {celltype_path}"
        )

    # create an HDF5 enumerated type for the population label
    mapping = {name: idx for name, idx, _ in populations}
    dt_population_labels = h5py.special_dtype(enum=(np.uint16, mapping))
    with h5py.File(output_path, "x") as h5:
        h5[path_population_labels] = dt_population_labels
        # create an HDF5 compound type for population ranges
        dt_populations = np.dtype(
            [
                ("Start", np.uint64),
                ("Count", np.uint32),
                ("Population", h5[path_population_labels].dtype),
            ]
        )
        h5[path_population_range] = dt_populations
        dt = h5[path_population_range].dtype
        g = h5_get_group(h5, grp_h5types)
        dset = h5_get_dataset(
            g, grp_populations, maxshape=(len(populations),), dtype=dt
        )
        dset.resize((len(populations),))
        a = np.zeros(len(populations), dtype=dt)
        start = 0
        # Rows follow the sorted order. Indexing by position (not by
        # typeIndex - min) keeps non-contiguous type indices in bounds.
        for row, (_, idx, count) in enumerate(populations):
            a[row]["Start"] = start
            a[row]["Count"] = count
            a[row]["Population"] = idx
            start += count
        dset[:] = a
    return populations
def import_spikeraster(
    celltype_path,
    spikeraster_path,
    output_path,
    output_npy=False,
    namespace="Spike Data",
    progress=False,
    comm=None,
):
    """
    Import a spike raster into a new NeuroH5 file.

    Population definitions are created in ``output_path`` from ``celltype_path``
    via ``import_celltypes``. Each spike is assigned to a population by its
    gid, using the cumulative population counts. The spike times of every gid
    are then written as attribute ``"t"`` (float32) in ``namespace``, one
    population at a time.

    Parameters
    ----------
    celltype_path : str
        Tab-separated cell type table (see ``import_celltypes``).
    spikeraster_path : str
        ``.npy`` file with fields ``time``/``gid``, or a whitespace-separated
        text file with columns time and gid.
    output_path : str
        H5 file to create.
    output_npy : bool
        When reading a text raster, also save it as ``<spikeraster_path>.npy``.
    namespace : str
        Cell attribute namespace to write.
    progress : bool
        Show a tqdm progress bar while grouping spikes.
    comm : MPI.Intracomm, optional
        Communicator for the writes (default ``MPI.COMM_WORLD``).
    """
    if progress:
        import tqdm
    if comm is None:
        comm = MPI.COMM_WORLD

    populations = import_celltypes(celltype_path, output_path)
    pop_names = [name for name, _, _ in populations]

    # Upper gid boundary of every population but the last; np.digitize then
    # maps each gid to the position of its population in `populations`.
    pop_range_bins = []
    start = 0
    for _, _, count in populations[:-1]:
        start += count
        pop_range_bins.append(start)
    total = sum(count for _, _, count in populations)
    logger.info(
        f"populations: {populations} total: {total} pop_range_bins: {pop_range_bins}"
    )

    logger.info(f"Reading spike data from file {spikeraster_path}...")
    if spikeraster_path.endswith(".npy"):
        spike_array = np.load(spikeraster_path)
    else:
        spike_array = np.loadtxt(
            spikeraster_path,
            dtype=np.dtype([("time", np.float32), ("gid", np.uint32)]),
        )
        if output_npy:
            np.save(f"{spikeraster_path}.npy", spike_array)
    logger.info(f"Done reading spike data from file {spikeraster_path}")

    gid_array = spike_array["gid"]
    time_array = spike_array["time"]
    gid_bins = np.digitize(gid_array, np.asarray(pop_range_bins))

    pop_spk_dict = defaultdict(lambda: defaultdict(list))
    spike_iter = zip(gid_array, gid_bins, time_array)
    if progress:
        spike_iter = tqdm.tqdm(spike_iter, unit_scale=True)
    for gid, pop_idx, spk_t in spike_iter:
        pop_spk_dict[pop_names[pop_idx]][gid].append(spk_t)

    # NOTE(review): every rank reads the full raster and passes the same dict to
    # the collective write below — confirm this is intended for comm sizes > 1.
    for pop_name in pop_names:
        this_spk_dict = pop_spk_dict[pop_name]
        logger.info(
            f"Saving spike data for population {pop_name} gid set {sorted(this_spk_dict.keys())}"
        )
        output_dict = {
            gid: {"t": np.asarray(spk_ts, dtype=np.float32)}
            for gid, spk_ts in this_spk_dict.items()
        }
        write_cell_attributes(
            output_path, pop_name, output_dict, namespace=namespace, comm=comm
        )
        logger.info(
            f"Saved spike data for population {pop_name} to file {output_path}"
        )
    comm.barrier()
def create_neural_h5(
    output_filepath: str,
    cell_distributions: config.CellDistributions,
    synapses: config.Synapses,
    gap_junctions: Optional[Dict] = None,
    populations: Optional[Dict[str, config.PopulationsDef]] = None,
) -> None:
    """
    Write the NeuroH5 ``/H5Types`` definitions of a network into an H5 file.

    Creates the population label enum and the population range table. Start
    offsets are assigned contiguously in ascending population index order, and
    each count is summed over the layers in ``cell_distributions``. Also writes
    the table of valid projections.

    :param output_filepath: H5 file, opened in append mode
    :param cell_distributions: per population, a mapping of layer name to cell count
    :param synapses: mapping of postsynaptic population to a mapping keyed by
        presynaptic population; its keys define the projections
    :param gap_junctions: if non-empty, a mapping keyed by ``(post, pre)``
        pairs whose keys define the projections instead of ``synapses``
    :param populations: mapping of population name to index; defaults to
        ``config.PopulationsDef.__members__``
    """
    if populations is None:
        populations = config.PopulationsDef.__members__

    # (name, index, total count over layers), ordered by population index
    _populations = []
    for pop_name, pop_idx in populations.items():
        pop_count = sum(
            layer_count for _, layer_count in cell_distributions[pop_name].items()
        )
        _populations.append((pop_name, pop_idx, pop_count))
    _populations.sort(key=lambda x: x[1])

    # (source index, destination index) for every projection
    projections = []
    if gap_junctions:
        for (post, pre), _ in gap_junctions.items():
            projections.append((populations[pre], populations[post]))
    else:
        for post, connection_dict in synapses.items():
            for pre, _ in connection_dict.items():
                projections.append((populations[pre], populations[post]))

    # create an HDF5 enumerated type for the population label
    mapping = {name: idx for name, idx in populations.items()}
    dt_population_labels = h5py.special_dtype(enum=(np.uint16, mapping))
    with h5py.File(output_filepath, "a") as h5:
        h5[path_population_labels] = dt_population_labels
        # create an HDF5 compound type for population ranges
        dt_populations = np.dtype(
            [
                ("Start", np.uint64),
                ("Count", np.uint32),
                ("Population", h5[path_population_labels].dtype),
            ]
        )
        h5[path_population_range] = dt_populations
        dt = h5[path_population_range].dtype
        g = h5_get_group(h5, grp_h5types)
        dset = h5_get_dataset(
            g, grp_populations, maxshape=(len(_populations),), dtype=dt
        )
        dset.resize((len(_populations),))
        a = np.zeros(len(_populations), dtype=dt)
        start = 0
        for row, (_, idx, count) in enumerate(_populations):
            a[row]["Start"] = start
            a[row]["Count"] = count
            a[row]["Population"] = idx
            start += count
        dset[:] = a

        dt_projections = np.dtype(
            [
                ("Source", h5[path_population_labels].dtype),
                ("Destination", h5[path_population_labels].dtype),
            ]
        )
        h5[path_population_projections] = dt_projections
        # NOTE(review): unlike the range table above, the named Datatype object
        # itself (not its .dtype) is passed to create_dataset. This is
        # presumably intentional, so the dataset references the named type —
        # confirm.
        dt_projection_type = h5[path_population_projections]
        dset = h5_get_dataset(
            g,
            grp_valid_population_projections,
            maxshape=(len(projections),),
            dtype=dt_projection_type,
        )
        dset.resize((len(projections),))
        a = np.zeros(len(projections), dtype=dt_projection_type.dtype)
        for i, (src, dst) in enumerate(projections):
            a[i]["Source"] = int(src)
            a[i]["Destination"] = int(dst)
        dset[:] = a
def mkout(env: AbstractEnv, results_filename: str) -> None:
    """
    Creates simulation results file and adds H5Types group compatible with NeuroH5.

    If ``env.model_config`` names a ``"Cell Data"`` file, ``/H5Types`` is copied
    from it into ``results_filename`` (opened in append mode), unless that file
    already has it. Otherwise ``make_h5types(env, results_filename)`` is called.

    :param env:
    :param results_filename:
    :return:
    """
    if "Cell Data" in env.model_config:
        dataset_path = os.path.join(env.dataset_prefix, env.datasetName)
        data_file_path = os.path.join(
            dataset_path, env.model_config["Cell Data"]
        )
        # Context managers close both files even if the copy raises.
        with h5py.File(data_file_path, "r") as data_file, h5py.File(
            results_filename, "a"
        ) as results_file:
            if "H5Types" not in results_file:
                data_file.copy("/H5Types", results_file)
    else:
        # NOTE(review): make_h5types is neither defined nor imported in this
        # module as shown, so this branch raises NameError unless it is
        # provided elsewhere — confirm.
        make_h5types(env, results_filename)
def spikeout(
    env: AbstractEnv,
    output_path: str,
    t_start: Optional[float] = None,
    clear_data: bool = False,
) -> None:
    """
    Writes spike times to specified NeuroH5 output file.

    Spikes in ``env.t_vec``/``env.id_vec`` are grouped by population, using
    each cell type's ``"start"`` gid. They are appended to the namespace
    ``"Spike Events"`` (suffixed with ``env.results_namespace_id`` when set).
    Per gid, the attributes written are ``"t"`` (spike times shifted to
    trial-relative time), ``"Trial Duration"``, ``"Trial Index"`` and
    ``"artificial"``. Collective over ``env.comm``: every rank calls
    ``append_cell_attributes`` once per population, even with no local spikes.

    :param env:
    :param output_path:
    :param t_start: if given, spikes before this time are dropped
    :param clear_data: resize ``env.t_vec`` and ``env.id_vec`` to 0 afterwards
    :return:
    """
    equilibration_duration = float(
        env.stimulus_config["Equilibration Duration"]
    )
    n_trials = env.n_trials
    t_vec = env.t_vec.as_numpy()
    id_vec = np.asarray(env.id_vec.as_numpy(), dtype=np.uint32)
    trial_time_ranges = get_trial_time_ranges(
        env.t_rec.to_python(), env.n_trials
    )
    trial_time_bins = [
        t_trial_start for t_trial_start, t_trial_end in trial_time_ranges
    ]
    trial_dur = np.asarray(
        [env.tstop + equilibration_duration] * n_trials, dtype=np.float32
    )

    # Assign every spike to a population. Populations are ordered by start
    # gid, and gids are digitized against all start gids but the first.
    typelst = sorted(env.celltypes.keys())
    binvect = np.asarray([env.celltypes[k]["start"] for k in typelst])
    sort_idx = np.argsort(binvect, axis=0)
    pop_names = [typelst[i] for i in sort_idx]
    bins = binvect[sort_idx][1:]
    inds = np.digitize(id_vec, bins)

    if env.results_namespace_id is None:
        namespace_id = "Spike Events"
    else:
        namespace_id = f"Spike Events {str(env.results_namespace_id)}"

    for i, pop_name in enumerate(pop_names):
        spkdict = {}
        # No emptiness guard: np.where returns a 1-tuple, so `len(...) > 0` was
        # always true. An empty selection simply yields no gids.
        sinds = np.where(inds == i)
        ids = id_vec[sinds]
        ts = t_vec[sinds]
        for gid, t in zip(ids, ts):
            if (t_start is None) or (t >= t_start):
                if gid in spkdict:
                    spkdict[gid]["t"].append(t)
                else:
                    spkdict[gid] = {"t": [t]}
        for gid, gid_attrs in spkdict.items():
            is_artificial = gid in env.artificial_cells[pop_name]
            spiketrain = np.array(gid_attrs["t"], dtype=np.float32)
            if gid in env.spike_onset_delay:
                spiketrain -= env.spike_onset_delay[gid]
            trial_bins = np.digitize(spiketrain, trial_time_bins) - 1
            trial_spikes = [
                np.copy(spiketrain[np.where(trial_bins == trial_i)[0]])
                for trial_i in range(n_trials)
            ]
            # Shift each trial's spikes (in place) to trial-relative time.
            for trial_i, trial_spiketrain in enumerate(trial_spikes):
                trial_spiketrain -= (
                    np.sum(trial_dur[:trial_i]) + equilibration_duration
                )
            # NOTE(review): spikes outside every trial range are dropped from
            # "t" but their (negative/overflowing) bin stays in "Trial Index",
            # and uint8 wraps beyond 255 trials — confirm acceptable.
            gid_attrs["t"] = np.concatenate(trial_spikes)
            gid_attrs["Trial Duration"] = trial_dur
            gid_attrs["Trial Index"] = np.asarray(trial_bins, dtype=np.uint8)
            gid_attrs["artificial"] = np.asarray(
                [1 if is_artificial else 0], dtype=np.uint8
            )
        append_cell_attributes(
            output_path,
            pop_name,
            spkdict,
            namespace=namespace_id,
            comm=env.comm,
            io_size=env.io_size,
        )
        del spkdict

    # Only clear after all populations are written: t_vec/id_vec may be views
    # of these vectors.
    if clear_data:
        env.t_vec.resize(0)
        env.id_vec.resize(0)
    env.comm.barrier()
    if env.comm.Get_rank() == 0:
        logger.info(f"*** Output spike results to file {output_path}")
def recsout(
    env: AbstractEnv,
    output_path: str,
    t_start: Optional[float] = None,
    clear_data: bool = False,
    write_cell_location_data: bool = False,
    write_trial_data: bool = False,
) -> None:
    """
    Writes intracellular state traces to specified NeuroH5 output file.

    For each population and recording type, the cell attribute namespace
    ``"Intracellular <rec_type>"`` (suffixed with ``env.results_namespace_id``
    when set) is appended. Per gid it holds one float32 trace per recorder label
    and the time vector ``"t"``, shifted to trial-relative time. Collective over
    ``env.comm``.

    :param env:
    :param output_path:
    :param t_start: if given, samples recorded before this time are dropped
    :param clear_data: resize recorder vectors and ``env.t_rec`` to 0 after use
    :param write_cell_location_data: also write ``distance``/``section``/``loc``
        when the recorder provides them
    :param write_trial_data: also write the trial durations as ``"trial duration"``
    :raises RuntimeError: if a gid has duplicate recorder labels and
        ``env.recording_profile["reduce"]`` is not ``True``
    """
    t_rec = env.t_rec
    equilibration_duration = float(
        env.stimulus_config["Equilibration Duration"]
    )
    reduce_data = env.recording_profile.get("reduce", None)
    n_trials = env.n_trials
    trial_time_ranges = get_trial_time_ranges(
        env.t_rec.to_python(), env.n_trials
    )
    trial_time_bins = [
        t_trial_start for t_trial_start, t_trial_end in trial_time_ranges
    ]
    trial_dur = np.asarray(
        [env.tstop + equilibration_duration] * n_trials, dtype=np.float32
    )
    # Convert the shared time base once. Each recorder gets its own copy below
    # because its time vector is shifted in place and must never alias t_rec.
    t_rec_array = np.array(t_rec, dtype=np.float32)

    for pop_name in sorted(env.celltypes.keys()):
        local_rec_types = list(env.recs_dict[pop_name].keys())
        # Union over all ranks, so every rank iterates the same rec_types and
        # makes the same sequence of collective append_cell_attributes calls.
        rec_types = sorted(
            set(env.comm.allreduce(local_rec_types, op=mpi_op_concat))
        )
        for rec_type in rec_types:
            # NOTE(review): assumes env.recs_dict[pop_name] yields an empty
            # collection for rec types absent on this rank — confirm.
            recs = env.recs_dict[pop_name][rec_type]
            attr_dict = defaultdict(dict)
            for rec in recs:
                gid = rec["gid"]
                # NumPy >= 2 raises for copy=False whenever a copy is required
                # (e.g. float64 -> float32). Copy explicitly when the recorder
                # is cleared below; otherwise copy only if needed.
                if clear_data:
                    data_vec = np.array(rec["vec"], dtype=np.float32)
                else:
                    data_vec = np.asarray(rec["vec"], dtype=np.float32)
                time_vec = t_rec_array.copy()
                if t_start is not None:
                    time_inds = np.where(time_vec >= t_start)[0]
                    time_vec = time_vec[time_inds]
                    data_vec = data_vec[time_inds]
                trial_bins = np.digitize(time_vec, trial_time_bins) - 1
                for trial_i in range(n_trials):
                    trial_inds = np.where(trial_bins == trial_i)[0]
                    time_vec[trial_inds] -= (
                        np.sum(trial_dur[:trial_i]) + equilibration_duration
                    )
                label = rec["label"]
                if label in attr_dict[gid]:
                    if reduce_data is None:
                        raise RuntimeError(
                            "recsout: duplicate recorder labels and no reduce strategy specified"
                        )
                    elif reduce_data is True:
                        attr_dict[gid][label] += data_vec
                    else:
                        raise RuntimeError(
                            "recsout: unsupported reduce strategy specified"
                        )
                else:
                    attr_dict[gid][label] = data_vec
                    attr_dict[gid]["t"] = time_vec
                if write_trial_data:
                    attr_dict[gid]["trial duration"] = trial_dur
                if write_cell_location_data:
                    distance = rec.get("distance", None)
                    if distance is not None:
                        attr_dict[gid]["distance"] = np.asarray(
                            [distance], dtype=np.float32
                        )
                    section = rec.get("section", None)
                    if section is not None:
                        attr_dict[gid]["section"] = np.asarray(
                            [section], dtype=np.int16
                        )
                    loc = rec.get("loc", None)
                    if loc is not None:
                        attr_dict[gid]["loc"] = np.asarray(
                            [loc], dtype=np.float32
                        )
                if clear_data:
                    rec["vec"].resize(0)
            if env.results_namespace_id is None:
                namespace_id = f"Intracellular {rec_type}"
            else:
                namespace_id = "Intracellular {} {}".format(
                    rec_type, str(env.results_namespace_id)
                )
            append_cell_attributes(
                output_path,
                pop_name,
                attr_dict,
                namespace=namespace_id,
                comm=env.comm,
                io_size=env.io_size,
            )
    if clear_data:
        env.t_rec.resize(0)
    env.comm.barrier()
    if env.comm.Get_rank() == 0:
        logger.info(
            f"*** Output intracellular state results to file {output_path}"
        )
def lfpout(env: AbstractEnv, output_path: str):
    """
    Writes local field potential voltage traces to specified HDF5 output file.

    For each LFP in ``env.lfp``, a group ``"Local Field Potential <label>"``
    (suffixed with ``env.results_namespace_id`` when set) is created in
    ``output_path``, holding float32 datasets ``t`` and ``v`` (``lfp.meanlfp``).
    h5py raises if such a group already exists.

    :param env:
    :param output_path:
    :return:
    """
    for lfp in list(env.lfp.values()):
        if env.results_namespace_id is None:
            namespace_id = f"Local Field Potential {str(lfp.label)}"
        else:
            namespace_id = "Local Field Potential {} {}".format(
                str(lfp.label), str(env.results_namespace_id)
            )
        # The context manager closes the file even if a write fails.
        with h5py.File(output_path, "a") as output:
            grp = output.create_group(namespace_id)
            grp["t"] = np.asarray(lfp.t, dtype=np.float32)
            grp["v"] = np.asarray(lfp.meanlfp, dtype=np.float32)
    if env.comm.Get_rank() == 0:
        logger.info(f"*** Output LFP results to file {output_path}")
def get_h5py_attr(attrs, key):
    """
    str values are stored as bytes in h5py container attrs dictionaries. This
    function returns them as str on read. Values should be converted during
    write with the companion function set_h5py_attr.

    Bytes values (including ``np.bytes_``) are decoded as UTF-8, and sequences
    whose first element is str or bytes become a NumPy str array. All other
    values are returned unchanged. (``np.string_`` was removed in NumPy 2, and
    bytes were previously never decoded despite this contract.)

    :param attrs: :class:'h5py._hl.attrs.AttributeManager'
    :param key: str
    :return: val with type converted if str/bytes or array of str/bytes
    :raises KeyError: if ``key`` is not in ``attrs``
    """
    if key not in attrs:
        raise KeyError(f"get_h5py_attr: invalid key: {key}")
    val = attrs[key]
    if isinstance(val, bytes):
        return val.decode("utf-8")
    if isinstance(val, str):
        return val
    if isinstance(val, Iterable) and len(val) > 0:
        first = val[0]
        if isinstance(first, str):
            return np.array(val, dtype="str")
        if isinstance(first, bytes):
            return np.array(
                [v.decode("utf-8") if isinstance(v, bytes) else v for v in val],
                dtype="str",
            )
    return val
def set_h5py_attr(attrs, key, val):
    """
    str values are stored as bytes in h5py container attrs dictionaries. This
    function converts str to ``np.bytes_`` (and sequences of str to a NumPy
    bytes array) on write, encoding as UTF-8. Values should be converted back
    to str during read with the companion function get_h5py_attr.

    ``np.bytes_`` replaces ``np.string_``, which was removed in NumPy 2.

    :param attrs: :class:'h5py._hl.attrs.AttributeManager'
    :param key: str
    :param val: type converted if str or array of str
    """
    if isinstance(val, str):
        val = np.bytes_(val.encode("utf-8"))
    elif isinstance(val, Iterable) and len(val) > 0:
        if isinstance(val[0], str):
            val = np.char.encode(np.asarray(val, dtype=str), "utf-8")
    attrs[key] = val
def get_h5py_group(file, hierarchy, create=False):
    """
    Walk ``hierarchy`` down from ``file`` and return the group reached.

    ``None`` entries are skipped and other entries are converted with ``str``.
    A missing level is created when ``create`` is true; otherwise a KeyError
    listing the valid keys is raised.

    :param file: :class: in ['h5py.File', 'h5py.Group']
    :param hierarchy: list of str
    :param create: bool
    :return: :class:'h5py.Group'
    """
    node = file
    for name in hierarchy:
        if name is None:
            continue
        name = str(name)
        if name in node:
            node = node[name]
        elif create:
            node = node.create_group(name)
        else:
            raise KeyError(
                "get_h5py_group: target: %s does not contain key: %s; valid keys: %s"
                % (node, name, list(node.keys()))
            )
    return node
def write_cell_selection(
    env: AbstractEnv, write_selection_file_path, populations=None, write_kwds=None
):
    """
    Writes out the data necessary to instantiate the selected cells.

    For each population, the selected gids in ``env.cell_selection`` are read
    from ``env.data_file_path``. If the population has a ``Trees`` namespace,
    its morphology trees are read; otherwise, if it has ``Coordinates``, its
    coordinate attributes are read. The data is written to
    ``write_selection_file_path``. Collective over ``env.comm``; the selected
    gids are distributed round-robin across ranks.

    :param env: an instance of the `Env` class
    :param write_selection_file_path: output NeuroH5 file path
    :param populations: population names to write; defaults to all populations
        in ``env.cell_selection``
    :param write_kwds: extra keyword arguments for ``append_cell_trees`` and
        ``write_cell_attributes``. ``comm``/``io_size`` default to
        ``env.comm``/``env.io_size``; the caller's dict is not modified.
    :raises RuntimeError: if the number of trees read differs from the number
        of selected gids on this rank
    """
    # Work on a copy: the former `write_kwds={}` default was mutated in place,
    # so the first call's comm/io_size leaked into every later call.
    write_kwds = dict(write_kwds) if write_kwds is not None else {}
    write_kwds.setdefault("comm", env.comm)
    write_kwds.setdefault("io_size", env.io_size)

    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())
    data_file_path = env.data_file_path

    if populations is None:
        pop_names = sorted(env.cell_selection.keys())
    else:
        pop_names = populations

    for pop_name in pop_names:
        gid_range = [
            gid
            for i, gid in enumerate(env.cell_selection[pop_name])
            if i % nhosts == rank
        ]
        trees_output_dict = {}
        coords_output_dict = {}
        if (pop_name in env.cell_attribute_info) and (
            "Trees" in env.cell_attribute_info[pop_name]
        ):
            if rank == 0:
                logger.info(f"*** Reading trees for population {pop_name}")
            cell_tree_iter, _ = scatter_read_tree_selection(
                data_file_path,
                pop_name,
                selection=gid_range,
                topology=False,
                comm=env.comm,
                io_size=env.io_size,
            )
            if rank == 0:
                logger.info(f"*** Done reading trees for population {pop_name}")
            for gid, tree in cell_tree_iter:
                trees_output_dict[gid] = tree
            # Explicit check instead of `assert` (stripped under python -O).
            if len(trees_output_dict) != len(gid_range):
                raise RuntimeError(
                    f"write_cell_selection: rank {rank} read {len(trees_output_dict)} "
                    f"trees for population {pop_name} but selected {len(gid_range)} gids"
                )
        elif (pop_name in env.cell_attribute_info) and (
            "Coordinates" in env.cell_attribute_info[pop_name]
        ):
            if rank == 0:
                logger.info(
                    f"*** Reading coordinates for population {pop_name}"
                )
            cell_attributes_iter = scatter_read_cell_attribute_selection(
                data_file_path,
                pop_name,
                selection=gid_range,
                namespace="Coordinates",
                comm=env.comm,
                io_size=env.io_size,
            )
            if rank == 0:
                logger.info(
                    f"*** Done reading coordinates for population {pop_name}"
                )
            for gid, coords in cell_attributes_iter:
                coords_output_dict[gid] = coords
        if rank == 0:
            logger.info(
                f"*** Writing cell selection for population {pop_name} to file {write_selection_file_path}"
            )
        append_cell_trees(
            write_selection_file_path, pop_name, trees_output_dict, **write_kwds
        )
        write_cell_attributes(
            write_selection_file_path,
            pop_name,
            coords_output_dict,
            namespace="Coordinates",
            **write_kwds,
        )
        env.comm.barrier()
def write_connection_selection(
    env: AbstractEnv, write_selection_file_path, populations=None, write_kwds=None
):
    """
    Loads NeuroH5 connectivity file, and writes the corresponding
    synapse and network connection mechanisms for the selected postsynaptic cells.

    This covers every postsynaptic population in ``env.projection_dict`` that
    is also selected. For the selected gids it re-writes the following to
    ``write_selection_file_path``:

    * the ``"Synapse Attributes"`` namespace and any configured weight
      namespaces, read from ``env.forest_file_path``;
    * the ``Synapses/syn_id`` and ``Connections/distance`` edge attributes of
      every projection onto the population, read from
      ``env.connectivity_file_path``.

    Collective over ``env.comm``; gids are distributed round-robin across ranks.

    :param env: an instance of the `Env` class
    :param write_selection_file_path: output NeuroH5 file path
    :param populations: postsynaptic population names to process; defaults to
        all populations in ``env.cell_selection``
    :param write_kwds: extra keyword arguments for ``write_cell_attributes``.
        ``comm``/``io_size`` default to ``env.comm``/``env.io_size``; the
        caller's dict is not modified.
    :return: dict mapping every population in ``env.celltypes`` to a set that
        is updated with the presynaptic gids of each edge iterated on this rank
    :raises RuntimeError: if a projection lacks the required edge attributes
    """
    # Work on a copy: the former `write_kwds={}` default was mutated in place,
    # so the first call's comm/io_size leaked into every later call.
    write_kwds = dict(write_kwds) if write_kwds is not None else {}
    write_kwds.setdefault("comm", env.comm)
    write_kwds.setdefault("io_size", env.io_size)

    connectivity_file_path = env.connectivity_file_path
    forest_file_path = env.forest_file_path
    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())

    if populations is None:
        pop_names = sorted(env.cell_selection.keys())
    else:
        pop_names = populations

    input_sources = {pop_name: set() for pop_name in env.celltypes}

    for postsyn_name, presyn_names in sorted(env.projection_dict.items()):
        gc.collect()
        if postsyn_name not in pop_names:
            continue
        if rank == 0:
            logger.info(
                f"*** Writing connection selection of population {postsyn_name}"
            )
        gid_range = [
            gid
            for i, gid in enumerate(env.cell_selection[postsyn_name])
            if i % nhosts == rank
        ]

        synapse_config = env.celltypes[postsyn_name]["synapses"]
        weight_dicts = (
            synapse_config["weights"] if "weights" in synapse_config else []
        )

        if rank == 0:
            logger.info(
                f"*** Reading synaptic attributes for population {postsyn_name}"
            )
        syn_attributes_iter = scatter_read_cell_attribute_selection(
            forest_file_path,
            postsyn_name,
            selection=gid_range,
            namespace="Synapse Attributes",
            comm=env.comm,
            io_size=env.io_size,
        )
        syn_attributes_output_dict = dict(list(syn_attributes_iter))
        write_cell_attributes(
            write_selection_file_path,
            postsyn_name,
            syn_attributes_output_dict,
            namespace="Synapse Attributes",
            **write_kwds,
        )
        del syn_attributes_output_dict
        del syn_attributes_iter

        for weight_dict in weight_dicts:
            # NOTE(review): "namespace" is iterated, so it is expected to be a
            # list of namespace names (a plain str would iterate characters).
            weights_namespaces = weight_dict["namespace"]
            if rank == 0:
                logger.info(
                    "*** Reading synaptic weights of population {} from namespaces {}".format(
                        postsyn_name, str(weights_namespaces)
                    )
                )
            for weights_namespace in weights_namespaces:
                syn_weights_iter = scatter_read_cell_attribute_selection(
                    forest_file_path,
                    postsyn_name,
                    namespace=weights_namespace,
                    selection=gid_range,
                    comm=env.comm,
                    io_size=env.io_size,
                )
                weight_attributes_output_dict = dict(list(syn_weights_iter))
                write_cell_attributes(
                    write_selection_file_path,
                    postsyn_name,
                    weight_attributes_output_dict,
                    namespace=weights_namespace,
                    **write_kwds,
                )
                del weight_attributes_output_dict
                del syn_weights_iter

        logger.info(
            "*** Rank %i: reading connectivity selection from file %s for postsynaptic population: %s: selection: %s"
            % (rank, connectivity_file_path, postsyn_name, str(gid_range))
        )
        (graph, attr_info) = scatter_read_graph_selection(
            connectivity_file_path,
            selection=gid_range,
            projections=[
                (presyn_name, postsyn_name)
                for presyn_name in sorted(presyn_names)
            ],
            comm=env.comm,
            io_size=env.io_size,
            namespaces=["Synapses", "Connections"],
        )

        for presyn_name in sorted(presyn_names):
            gid_dict = {}
            edge_count = 0
            node_count = 0
            if postsyn_name in graph:
                if (
                    postsyn_name in attr_info
                    and presyn_name in attr_info[postsyn_name]
                ):
                    edge_attr_info = attr_info[postsyn_name][presyn_name]
                else:
                    raise RuntimeError(
                        "write_connection_selection: missing edge attributes for projection %s -> %s"
                        % (presyn_name, postsyn_name)
                    )
                if (
                    "Synapses" in edge_attr_info
                    and "syn_id" in edge_attr_info["Synapses"]
                    and "Connections" in edge_attr_info
                    and "distance" in edge_attr_info["Connections"]
                ):
                    syn_id_attr_index = edge_attr_info["Synapses"]["syn_id"]
                    distance_attr_index = edge_attr_info["Connections"][
                        "distance"
                    ]
                else:
                    raise RuntimeError(
                        "write_connection_selection: missing edge attributes for projection %s -> %s"
                        % (presyn_name, postsyn_name)
                    )
                # Record the presynaptic gids of every edge set as it is
                # iterated (edgeset = (postsyn_gid, (presyn_gids, attrs))).
                edge_iter = compose_iter(
                    lambda edgeset: input_sources[presyn_name].update(
                        edgeset[1][0]
                    ),
                    graph[postsyn_name][presyn_name],
                )
                for postsyn_gid, edges in edge_iter:
                    presyn_gids, edge_attrs = edges
                    edge_syn_ids = edge_attrs["Synapses"][syn_id_attr_index]
                    edge_dists = edge_attrs["Connections"][distance_attr_index]
                    gid_dict[postsyn_gid] = (
                        presyn_gids,
                        {
                            "Synapses": {"syn_id": edge_syn_ids},
                            "Connections": {"distance": edge_dists},
                        },
                    )
                    edge_count += len(presyn_gids)
                    node_count += 1
            env.comm.barrier()
            logger.info(
                "*** Rank %d: Writing projection %s -> %s selection: %d nodes, %d edges"
                % (rank, presyn_name, postsyn_name, node_count, edge_count)
            )
            write_graph(
                write_selection_file_path,
                src_pop_name=presyn_name,
                dst_pop_name=postsyn_name,
                edges=gid_dict,
                comm=env.comm,
                io_size=env.io_size,
            )
        env.comm.barrier()
    return input_sources
def query_cell_attributes(input_file, population_names, namespace_ids=None):
    """
    Read the cell attribute metadata of ``population_names`` from ``input_file``.

    :param input_file: NeuroH5 file path
    :param population_names: populations to query
    :param namespace_ids: namespaces to report. If ``None``, the namespaces
        found for the queried populations are reported.
    :return: ``(namespace_id_lst, attr_info_dict)``. ``attr_info_dict`` is the
        result of ``read_cell_attribute_info``, keyed by population.
        ``namespace_id_lst`` is ``namespace_ids``, or, when that is ``None``,
        the namespaces of all returned populations in first-seen order. It is
        empty when no population was found.
    """
    logger.info("Querying cell attribute data...")
    attr_info_dict = read_cell_attribute_info(
        input_file, populations=population_names, read_cell_index=True
    )
    if not attr_info_dict:
        return [], attr_info_dict
    if namespace_ids is not None:
        return namespace_ids, attr_info_dict
    # Collect namespaces across all populations. The previous loop
    # overwrote the list for each population, so only the last one's
    # namespaces were returned.
    namespace_id_lst = list(
        dict.fromkeys(
            namespace
            for pop_name in attr_info_dict
            for namespace in attr_info_dict[pop_name].keys()
        )
    )
    return namespace_id_lst, attr_info_dict
def _run(commands):
    """
    Echo ``commands`` as a shell-quoted line, then run it without a shell.

    The command's stdout is captured and discarded.
    ``subprocess.CalledProcessError`` propagates on a non-zero exit status.
    """
    import shlex

    # Quote each argument: HDF5 paths such as
    # "/Populations/STIM/Generated Coordinates" contain spaces, and a plain
    # " ".join would echo a line that cannot be pasted back into a shell.
    print(" ".join(shlex.quote(arg) for arg in commands))
    subprocess.check_output(commands)
def copy_dataset(f_src: "h5py.File", f_dst: "h5py.File", dset_path: str) -> None:
    """
    Copy the object at ``dset_path`` in ``f_src`` to the same path in ``f_dst``.

    HDF5 paths always use ``/`` as separator, so the parent path is computed
    with ``PurePosixPath``; a plain ``Path`` would produce backslashes on
    Windows. The parent group is created in ``f_dst`` if it is missing.
    """
    print(f"Copying {dset_path} from {f_src} to {f_dst} ...")
    target_path = str(pathlib.PurePosixPath(dset_path).parent)
    f_src.copy(f_src[dset_path], f_dst.require_group(target_path))
class H5FileManager:
    """Utility to manage NeuroH5 simulator files.

    The files live in ``directory``: the cell data file ``cells.h5`` and the
    connectivity file ``connections.h5``.
    """

    def __init__(self, directory: str):
        self.directory = directory

    def local_directory(self, *append: str, create: bool = False) -> str:
        """Return the absolute path of ``directory`` joined with ``append``.

        With ``create=True`` the path is created as a directory (no error if
        it already exists).
        """
        d = os.path.join(os.path.abspath(self.directory), *append)
        if create:
            os.makedirs(d, exist_ok=True)
        return d

    @property
    def cells_filepath(self) -> str:
        """Absolute path of ``cells.h5``."""
        return self.local_directory("cells.h5")

    @property
    def connections_filepath(self) -> str:
        """Absolute path of ``connections.h5``."""
        return self.local_directory("connections.h5")

    @staticmethod
    def _h5copy(src_file: str, src_path: str, dst_file: str, dst_path: str) -> None:
        """Copy object ``src_path`` of ``src_file`` to ``dst_path`` of ``dst_file`` with the ``h5copy`` tool."""
        _run(
            [
                "h5copy",
                "-p",
                "-s",
                src_path,
                "-d",
                dst_path,
                "-i",
                src_file,
                "-o",
                dst_file,
            ]
        )

    def import_h5types(self, src: str):
        """(Re)create the cells and connections files, each with a copy of
        ``/H5Types`` from ``src``.

        Both output files are truncated. ``src`` is opened once, before either
        output is truncated, and is closed even if a copy fails.
        """
        with h5py.File(src, "r") as input_file:
            for dst_filepath in (self.cells_filepath, self.connections_filepath):
                with h5py.File(dst_filepath, "w") as f:
                    copy_dataset(input_file, f, "/H5Types")

    def import_soma_coordinates(self, src: str, populations: Tuple[str, ...] = ()):
        """Copy the ``Generated Coordinates`` and ``Arc Distances`` of each
        population from ``src`` into the cells file.

        The ``/Populations/<p>`` groups are created if they do not exist yet,
        so the method can be re-run on an existing cells file.
        """
        with h5py.File(self.cells_filepath, "a") as f_dst:
            grp = f_dst.require_group("Populations")
            for p in populations:
                grp.require_group(p)
            if populations:
                with h5py.File(src, "r") as f_src:
                    for p in populations:
                        copy_dataset(
                            f_src, f_dst, f"/Populations/{p}/Generated Coordinates"
                        )
                        copy_dataset(f_src, f_dst, f"/Populations/{p}/Arc Distances")

    def import_synapse_attributes(
        self, population: str, forest_file: str, synapses_file: str
    ):
        """Copy the population's ``Trees`` from ``forest_file`` and its
        ``Synapse Attributes`` from ``synapses_file`` into the cells file."""
        forest_dset_path = f"/Populations/{population}/Trees"
        forest_syns_dset_path = f"/Populations/{population}/Synapse Attributes"
        self._h5copy(
            forest_file, forest_dset_path, self.cells_filepath, forest_dset_path
        )
        self._h5copy(
            synapses_file,
            forest_syns_dset_path,
            self.cells_filepath,
            forest_syns_dset_path,
        )

    def import_projections(self, population: str, src: str):
        """Copy ``/Projections/<population>`` from ``src`` into the connections file."""
        projection_dset_path = f"/Projections/{population}"
        self._h5copy(
            src, projection_dset_path, self.connections_filepath, projection_dset_path
        )

    def copy_stim_coordinates(self):
        """Within the cells file, copy STIM's ``Generated Coordinates`` to ``Coordinates``."""
        self._h5copy(
            self.cells_filepath,
            "/Populations/STIM/Generated Coordinates",
            self.cells_filepath,
            "/Populations/STIM/Coordinates",
        )