import logging
import os
from collections import defaultdict, namedtuple
from typing import Dict, Iterable, Optional, Union

import numpy as np
import yaml
from mpi4py import MPI
from mpi4py.MPI import Intracomm
from neuroh5.io import (
    read_cell_attribute_info,
    read_population_names,
    read_population_ranges,
    read_projection_names,
)

from miv_simulator.synapses import SynapseAttributes, get_syn_filter_dict
from miv_simulator.utils import (
    AbstractEnv,
    ExprClosure,
    IncludeLoader,
    get_root_logger,
    read_from_yaml,
)
# Record types produced while parsing the model configuration.
#
# The typename passed to namedtuple() always matches the module attribute the
# class is bound to. pickle (and therefore MPI broadcasts, which pickle their
# payload) locates a class by module + qualified name, so a mismatched typename
# makes instances unpicklable.

# Synapse configuration for one (postsynaptic, presynaptic) population pair,
# as built by Env.parse_connection_config.
SynapseConfig = namedtuple(
    "SynapseConfig",
    ["type", "sections", "layers", "proportions", "contacts", "mechanisms"],
)

# Gap junction configuration for one population pair, as built by
# Env.parse_gapjunction_config.
GapjunctionConfig = namedtuple(
    "GapjunctionConfig",
    [
        "sections",
        "connection_probability",
        "connection_parameters",
        "connection_bounds",
        "coupling_coefficients",
        "coupling_parameters",
        "coupling_bounds",
    ],
)

# Network clamp settings, as built by Env.parse_netclamp_config.
NetclampConfig = namedtuple(
    "NetclampConfig",
    ["template_params", "weight_generators", "optimize_parameters"],
)

# One entry of the "Stimulus/Arena" configuration section.
ArenaConfig = namedtuple(
    "ArenaConfig", ["name", "domain", "trajectories", "properties"]
)

# Arena domain geometry ("vertices" and "simplices" taken verbatim from config).
DomainConfig = namedtuple("DomainConfig", ["vertices", "simplices"])

# A trajectory: run velocity and an (N, 2) float32 array of path points.
StimulusConfig = namedtuple("StimulusConfig", ["velocity", "path"])
class Env(AbstractEnv):
    """
    Network model configuration.

    Rank 0 of the communicator reads the model configuration (and the dataset
    metadata it references) and broadcasts it to all other ranks. The parsed
    settings are stored as attributes of this object.
    """

    def __init__(
        self,
        comm: Optional[Intracomm] = None,
        config: Optional[Union[str, os.PathLike, Dict]] = None,
        template_paths: str = "templates",
        hoc_lib_path: Optional[str] = None,
        dataset_prefix: Optional[str] = None,
        results_path: Optional[str] = None,
        results_file_id: Optional[str] = None,
        results_namespace_id: Optional[str] = None,
        node_rank_file: Optional[str] = None,
        node_allocation: Optional[Iterable[int]] = None,
        io_size: int = 0,
        use_cell_attr_gen: bool = False,
        cell_attr_gen_cache_size: int = 10,
        recording_profile: Optional[str] = None,
        tstart: float = 0.0,
        tstop: Union[int, float] = 0.0,
        v_init: Union[int, float] = -65,
        stimulus_onset: float = 0.0,
        n_trials: int = 1,
        max_walltime_hours: float = 0.5,
        checkpoint_interval: float = 500.0,
        checkpoint_clear_data: bool = True,
        nrn_timeout: Optional[float] = 600.0,
        results_write_time: Union[int, float] = 0,
        dt: Optional[float] = None,
        ldbal: bool = False,
        lptbal: bool = False,
        cell_selection_path: Optional[str] = None,
        microcircuit_inputs: bool = False,
        spike_input_path: Optional[str] = None,
        spike_input_namespace: Optional[str] = None,
        spike_input_attr: Optional[str] = None,
        cleanup: bool = True,
        cache_queries: bool = False,
        profile_memory: bool = False,
        use_coreneuron: bool = False,
        transfer_debug: bool = False,
        verbose: bool = False,
        config_prefix: str = "",
        **kwargs,
    ) -> None:
        """
        :param comm: MPI communicator; defaults to :class:`MPI.COMM_WORLD`
        :param config: model configuration. It is either a path to a YAML file or an
            already-parsed configuration dict. A path is read on rank 0; a relative
            path is resolved against ``config_prefix`` when that is non-empty.
        :param template_paths: str; colon-separated list of paths to directories containing hoc cell templates
            (``None`` gives an empty list)
        :param hoc_lib_path: str; path to directory containing required hoc libraries
        :param dataset_prefix: str; path to directory containing required neuroh5 data files
        :param results_path: str; path to directory to export output files
        :param results_file_id: str; label for neuroh5 files to write spike and voltage trace data
        :param results_namespace_id: str; label for neuroh5 namespaces to write spike and voltage trace data
        :param node_rank_file: str; name of file specifying assignment of node gids to MPI ranks
        :param node_allocation: iterable; gids assigned to the current MPI rank; cannot be specified together
            with node_rank_file. An empty iterable means "no gids on this rank".
        :param io_size: int; the number of MPI ranks to be used for I/O operations
        :param use_cell_attr_gen: bool; whether to use cell attribute generation for I/O operations
        :param cell_attr_gen_cache_size: int; number of cache (readahead) items for cell attribute generation
        :param recording_profile: str; name of the ``Recording/Intracellular`` profile to use
        :param tstart: float; start of physical time to simulate (ms)
        :param tstop: float; physical time to simulate (ms)
        :param v_init: float; initialization membrane potential (mV)
        :param stimulus_onset: float; starting time of stimulus (ms)
        :param n_trials: int; number of trials
        :param max_walltime_hours: float; maximum wall time (hours)
        :param checkpoint_interval: float; checkpoint interval in ms of simulation time. Values <= 0 disable
            checkpointing; positive values are raised to at least 1.0.
        :param checkpoint_clear_data: bool; stored as ``self.checkpoint_clear_data``
        :param nrn_timeout: float; NEURON timeout value (``None`` is stored as 0)
        :param results_write_time: float; time to write out results at end of simulation
        :param dt: float; simulation time step (defaults to 0.025 when ``None``)
        :param ldbal: bool; estimate load balance based on cell complexity
        :param lptbal: bool; calculate load balance with LPT algorithm
        :param cell_selection_path: str; YAML file with a cell selection for simulating a subset of the
            network. If given, microcircuit inputs are enabled.
        :param microcircuit_inputs: bool; instantiate as spike sources those cells that do not have data
            in the input data file
        :param spike_input_path: str; neuroh5 file with spike input data
        :param spike_input_namespace: str; namespace of the spike input data
        :param spike_input_attr: str; attribute name of the spike input data
        :param cleanup: bool; clean up auxiliary cell and synapse structures after network init
        :param cache_queries: bool; whether to use a cache to speed up queries to filter_synapses
        :param profile_memory: bool; profile memory usage
        :param use_coreneuron: bool; if true, the recording profile ``dt`` is set to ``None``
        :param transfer_debug: bool; stored as ``self.transfer_debug``
        :param verbose: bool; print verbose diagnostic messages while constructing the network
        :param config_prefix: str; directory used to resolve a relative ``config`` path and the
            mechanism file paths of cell types
        :param kwargs: stored as ``self.kwargs``. They are forwarded to :meth:`init_stimulus_config`
            (e.g. ``arena_id``, ``stimulus_id``) when the configuration has a ``Stimulus`` section.
        :raises TypeError: if no model configuration is available (``config`` is ``None`` or the file is empty)
        :raises RuntimeError: if both ``node_rank_file`` and ``node_allocation`` are given
        """
        self.kwargs = kwargs
        self.SWC_Types = {}
        self.SWC_Type_index = {}
        self.Synapse_Types = {}
        self.Synapse_Type_index = {}
        self.layers = {}
        self.layer_type_index = {}
        self.globals = {}
        self.gidset = set()
        self.gjlist = []
        # defaultdict(dict) behaves like defaultdict(lambda: dict()), but avoids an
        # unpicklable lambda factory and matches clear().
        self.cells = defaultdict(dict)
        self.artificial_cells = defaultdict(dict)
        self.biophys_cells = defaultdict(dict)
        self.spike_onset_delay = {}
        self.recording_sets = {}
        self.pc = None

        self.comm = MPI.COMM_WORLD if comm is None else comm
        rank = self.comm.Get_rank()
        # comm0 includes only rank 0; it is used for the file reads that only
        # rank 0 performs before broadcasting the results.
        color = 1 if rank == 0 else 0
        comm0 = self.comm.Split(color, 0)

        self.use_coreneuron = use_coreneuron
        # If true, the biophysical cells and synapses dictionary will be freed
        # as synapses and connections are instantiated.
        self.cleanup = cleanup
        # If true, compute and print memory usage at various points
        # during simulation initialization
        self.profile_memory = profile_memory
        # print verbose diagnostic messages
        self.verbose = verbose
        self.logger = get_root_logger()
        if self.verbose:
            self.logger.setLevel(logging.INFO)

        # Directories for cell templates
        if template_paths is not None:
            self.template_paths = template_paths.split(":")
        else:
            self.template_paths = []
        self.template_dict = {}
        # The location of required hoc libraries
        self.hoc_lib_path = hoc_lib_path

        # Checkpoint interval in ms of simulation time
        self.checkpoint_clear_data = checkpoint_clear_data
        self.last_checkpoint = 0.0
        if checkpoint_interval > 0.0:
            self.checkpoint_interval = max(float(checkpoint_interval), 1.0)
        else:
            self.checkpoint_interval = None

        # NEURON timeout value (0 if None)
        self.nrn_timeout = int(nrn_timeout) if nrn_timeout is not None else 0

        # The location of all datasets
        self.dataset_prefix = dataset_prefix
        # The path where results files should be written
        self.results_path = results_path
        # Identifier used to construct results data namespaces
        self.results_namespace_id = results_namespace_id
        # Identifier used to construct results data files
        self.results_file_id = results_file_id
        # Number of MPI ranks to be used for I/O operations
        self.io_size = int(io_size)
        # Whether to use cell attribute generation for I/O operations
        # and number of cache (readahead) items
        self.use_cell_attr_gen = use_cell_attr_gen
        self.cell_attr_gen_cache_size = cell_attr_gen_cache_size
        # Initialization voltage
        self.v_init = float(v_init)
        # simulation time [ms]
        self.tstart = float(tstart)
        self.tstop = float(tstop)
        # stimulus onset time [ms]
        self.stimulus_onset = float(stimulus_onset)
        # number of trials
        self.n_trials = int(n_trials)
        # maximum wall time in hours
        self.max_walltime_hours = float(max_walltime_hours)
        # time to write out results at end of simulation
        self.results_write_time = float(results_write_time)
        # time step
        self.dt = float(dt if dt is not None else 0.025)
        # used to estimate cell complexity
        self.cxvec = None
        # measure/perform load balancing
        self.optldbal = ldbal
        self.optlptbal = lptbal
        self.transfer_debug = transfer_debug
        # cache queries to filter_synapses
        self.cache_queries = cache_queries

        # Load the model configuration on rank 0 and broadcast it.
        self.model_config = None
        self.config_prefix = config_prefix
        if rank == 0:
            if isinstance(config, (str, os.PathLike)):
                # load complete configuration from file
                config_path = os.fspath(config)
                if config_prefix != "" and not os.path.isabs(config_path):
                    config_path = os.path.join(config_prefix, config_path)
                # NOTE(review): yaml.load with a custom loader; this assumes the
                # configuration file is trusted input.
                with open(config_path) as fp:
                    self.model_config = yaml.load(fp, IncludeLoader)
            else:
                self.model_config = config
        self.model_config = self.comm.bcast(self.model_config, root=0)
        if self.model_config is None:
            # Every rank sees the same broadcast value, so the collective
            # Free() and the raise happen on all ranks together.
            comm0.Free()
            raise TypeError(
                "Env: no model configuration available "
                "(config is None or the configuration file is empty)"
            )

        if "Definitions" in self.model_config:
            self.parse_definitions()
        # Reverse lookups: type id -> type name
        self.SWC_Type_index = {v: k for k, v in self.SWC_Types.items()}
        self.Synapse_Type_index = {v: k for k, v in self.Synapse_Types.items()}
        self.layer_type_index = {v: k for k, v in self.layers.items()}

        if "Global Parameters" in self.model_config:
            self.parse_globals()

        self.geometry = None
        if "Geometry" in self.model_config:
            self.geometry = self.model_config["Geometry"]
            if "Origin" in self.geometry.get("Parametric Surface", {}):
                self.parse_origin_coords()

        self.celltypes = self.model_config["Cell Types"]
        self.cell_attribute_info = {}

        # The name of this model
        self.modelName = self.model_config.get("Model Name", "Unnamed model")
        # The dataset to use for constructing the network
        if "Dataset Name" in self.model_config:
            self.datasetName = self.model_config["Dataset Name"]

        if rank == 0:
            self.logger.info(f"env.dataset_prefix = {str(self.dataset_prefix)}")

        # Cell selection for simulations of subsets of the network
        self.cell_selection = None
        self.cell_selection_path = cell_selection_path
        if rank == 0:
            self.logger.info(
                f"env.cell_selection_path = {str(self.cell_selection_path)}"
            )
            if cell_selection_path is not None:
                with open(cell_selection_path) as fp:
                    self.cell_selection = yaml.load(fp, IncludeLoader)
        self.cell_selection = self.comm.bcast(self.cell_selection, root=0)

        # Spike input path
        self.spike_input_path = spike_input_path
        self.spike_input_ns = spike_input_namespace
        self.spike_input_attr = spike_input_attr
        self.spike_input_attribute_info = None
        if self.spike_input_path is not None:
            if rank == 0:
                self.logger.info(
                    f"env.spike_input_path = {str(self.spike_input_path)}"
                )
                self.spike_input_attribute_info = read_cell_attribute_info(
                    self.spike_input_path,
                    sorted(self.Populations.keys()),
                    comm=comm0,
                )
                self.logger.info(
                    "env.spike_input_attribute_info = %s"
                    % str(self.spike_input_attribute_info)
                )
            self.spike_input_attribute_info = self.comm.bcast(
                self.spike_input_attribute_info, root=0
            )

        # Results file: "<model>_results[_<id>].h5", optionally under results_path
        if self.results_file_id is None:
            results_file_name = f"{self.modelName}_results.h5"
        else:
            results_file_name = (
                f"{self.modelName}_results_{self.results_file_id}.h5"
            )
        if results_path:
            self.results_file_path = f"{self.results_path}/{results_file_name}"
        else:
            self.results_file_path = results_file_name

        if "Connection Generator" in self.model_config:
            self.parse_connection_config()
            self.parse_gapjunction_config()

        if self.dataset_prefix is not None:
            self.dataset_path = os.path.join(
                self.dataset_prefix, self.datasetName
            )
            if "Cell Data" in self.model_config:
                self.data_file_path = os.path.join(
                    self.dataset_path, self.model_config["Cell Data"]
                )
                self.forest_file_path = os.path.join(
                    self.dataset_path, self.model_config["Cell Data"]
                )
                self.load_celltypes()
            else:
                self.data_file_path = None
                self.forest_file_path = None
            if rank == 0:
                self.logger.info(f"env.data_file_path = {self.data_file_path}")
            if "Connection Data" in self.model_config:
                self.connectivity_file_path = os.path.join(
                    self.dataset_path, self.model_config["Connection Data"]
                )
            else:
                self.connectivity_file_path = None
            if "Gap Junction Data" in self.model_config:
                self.gapjunctions_file_path = os.path.join(
                    self.dataset_path, self.model_config["Gap Junction Data"]
                )
            else:
                self.gapjunctions_file_path = None
        else:
            self.dataset_path = None
            self.data_file_path = None
            self.connectivity_file_path = None
            self.forest_file_path = None
            self.gapjunctions_file_path = None

        # Gid-to-rank assignment. node_allocation is compared against None
        # (not truth-tested) so that an empty allocation is honored and numpy
        # arrays do not raise "truth value of an array is ambiguous".
        self.node_allocation = None
        if node_rank_file and node_allocation is not None:
            raise RuntimeError(
                "Only one of node_rank_file and node_allocation must be specified."
            )
        if node_rank_file:
            self.load_node_rank_map(node_rank_file)
        if node_allocation is not None:
            self.node_allocation = set(node_allocation)

        self.netclamp_config = None
        if "Network Clamp" in self.model_config:
            self.parse_netclamp_config()

        self.stimulus_config = None
        self.arena_id = None
        self.stimulus_id = None
        if "Stimulus" in self.model_config:
            self.parse_stimulus_config()
            self.init_stimulus_config(**kwargs)

        self.analysis_config = None
        if "Analysis" in self.model_config:
            self.analysis_config = self.model_config["Analysis"]

        # Projections: destination population -> list of source populations
        self.projection_dict = None
        if self.dataset_prefix is not None:
            if rank == 0:
                projection_dict = defaultdict(list)
                self.logger.info(
                    f"env.connectivity_file_path = {str(self.connectivity_file_path)}"
                )
                if self.connectivity_file_path is not None:
                    for src, dst in read_projection_names(
                        self.connectivity_file_path, comm=comm0
                    ):
                        projection_dict[dst].append(src)
                self.projection_dict = dict(projection_dict)
                self.logger.info(
                    f"projection_dict = {str(self.projection_dict)}"
                )
            self.projection_dict = self.comm.bcast(self.projection_dict, root=0)

        # If True, instantiate as spike source those cells that do not
        # have data in the input data file
        self.microcircuit_inputs = microcircuit_inputs or (
            self.cell_selection is not None
        )
        self.microcircuit_input_sources = {
            pop_name: set() for pop_name in self.celltypes.keys()
        }

        # Configuration profile for optogenetic stimulation
        self.opsin_config = None
        if "Stimulus" in self.model_config:
            if "Opsin" in self.model_config["Stimulus"]:
                opsin_cfg = self.model_config["Stimulus"]["Opsin"]
                self.opsin_config = {
                    "nstates": int(opsin_cfg["nstates"]),
                    "opsin type": opsin_cfg["opsin type"],
                    "protocol": opsin_cfg["protocol"],
                    "protocol parameters": opsin_cfg.get(
                        "protocol parameters", dict()
                    ),
                    "rho parameters": opsin_cfg.get("rho parameters", dict()),
                }

        # Configuration profile for recording intracellular quantities
        self.recording_profile = None
        if ("Recording" in self.model_config) and (
            recording_profile is not None
        ):
            # NOTE: the selected profile dict inside model_config is updated
            # in place ("label", per-quantity "syn_filters").
            self.recording_profile = self.model_config["Recording"][
                "Intracellular"
            ][recording_profile]
            self.recording_profile["label"] = recording_profile
            for recvar, recdict in self.recording_profile.get(
                "synaptic quantity", {}
            ).items():
                filters = {}
                if "syn types" in recdict:
                    filters["syn_types"] = recdict["syn types"]
                if "swc types" in recdict:
                    filters["swc_types"] = recdict["swc types"]
                if "layers" in recdict:
                    filters["layers"] = recdict["layers"]
                if "sources" in recdict:
                    filters["sources"] = recdict["sources"]
                syn_filters = get_syn_filter_dict(self, filters, convert=True)
                recdict["syn_filters"] = syn_filters
            # Only touch the profile when one was selected; otherwise
            # recording_profile is None and cannot be indexed.
            if self.use_coreneuron:
                self.recording_profile["dt"] = None

        # Configuration profile for recording local field potentials.
        # A "Recording" section without an "LFP" entry yields no LFP config.
        self.LFP_config = {}
        if "Recording" in self.model_config:
            lfp_configs = self.model_config["Recording"].get("LFP") or {}
            for label, lfp_cfg in lfp_configs.items():
                self.LFP_config[label] = {
                    "position": tuple(lfp_cfg["position"]),
                    "maxEDist": lfp_cfg["maxEDist"],
                    "fraction": lfp_cfg["fraction"],
                    "rho": lfp_cfg["rho"],
                    "dt": lfp_cfg["dt"],
                }

        self.t_vec = None
        self.id_vec = None
        self.t_rec = None
        self.recs_dict = {}  # Intracellular samples on this host
        self.recs_count = 0
        self.recs_pps_set = set()
        for pop_name in self.Populations:
            self.recs_dict[pop_name] = defaultdict(list)

        # used to calculate model construction times and run time
        self.mkcellstime = 0
        self.mkstimtime = 0
        self.connectcellstime = 0
        self.connectgjstime = 0
        self.simtime = None

        self.lfp = {}

        self.edge_count = defaultdict(dict)
        self.syns_set = defaultdict(set)

        comm0.Free()

    def parse_arena_domain(self, config):
        """
        Build a :data:`DomainConfig` from an arena ``Domain`` section.

        :param config: mapping with ``"vertices"`` and ``"simplices"`` entries (stored as-is)
        """
        return DomainConfig(config["vertices"], config["simplices"])

    def parse_arena_trajectory(self, config):
        """
        Build a :data:`StimulusConfig` from one arena ``Trajectory`` entry.

        :param config: mapping with ``"run velocity"`` and ``"path"`` (a sequence of points;
            only the first two coordinates of each point are used)
        :return: StimulusConfig(velocity as float, path as an (N, 2) float32 array)
        """
        velocity = float(config["run velocity"])
        path_config = config["path"]
        path_x = np.asarray([point[0] for point in path_config], dtype=np.float32)
        path_y = np.asarray([point[1] for point in path_config], dtype=np.float32)
        path = np.column_stack((path_x, path_y))
        return StimulusConfig(velocity, path)

    def init_stimulus_config(
        self,
        arena_id: Optional[str] = None,
        stimulus_id: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Select the active arena and trajectory from the parsed stimulus configuration.

        :param arena_id: arena name; if ``None`` nothing is selected
        :param stimulus_id: trajectory name within the arena; only considered when
            ``arena_id`` is given (``None`` clears the selection)
        :param kwargs: ignored (allows forwarding of unrelated keyword arguments)
        :raises RuntimeError: if the arena or trajectory is not in the stimulus configuration
        """
        if arena_id is None:
            if stimulus_id is not None:
                self.logger.warning(
                    "init_stimulus_config: stimulus id %s ignored because no arena id was specified"
                    % str(stimulus_id)
                )
            return

        # Missing stimulus configuration or "Arena" section is reported the same
        # way as an unknown arena id, instead of as a TypeError/KeyError.
        arenas = {}
        if self.stimulus_config is not None:
            arenas = self.stimulus_config.get("Arena", {})
        if arena_id not in arenas:
            raise RuntimeError(
                "init_stimulus_config: arena id parameter not found in stimulus configuration"
            )
        self.arena_id = arena_id

        if stimulus_id is None:
            self.stimulus_id = None
        elif stimulus_id in arenas[arena_id].trajectories:
            self.stimulus_id = stimulus_id
        else:
            raise RuntimeError(
                "init_stimulus_config: stimulus id parameter not found in stimulus configuration"
            )

    def _parse_selectivity_map(self, pop_dvals):
        """Map {population: {selectivity type name: value}} to {population: {type id: float(value)}}."""
        return {
            pop: {
                int(self.selectivity_types[type_name]): float(value)
                for type_name, value in dvals.items()
            }
            for pop, dvals in pop_dvals.items()
        }

    def _parse_arena(self, arena_id, arena_val):
        """Build an ArenaConfig from one entry of the ``Stimulus/Arena`` section."""
        arena_domain = None
        arena_trajectories = {}
        arena_properties = {}
        for kk, vv in arena_val.items():
            if kk == "Domain":
                arena_domain = self.parse_arena_domain(vv)
            elif kk == "Trajectory":
                for name, trajectory_config in vv.items():
                    arena_trajectories[name] = self.parse_arena_trajectory(
                        trajectory_config
                    )
            else:
                # Any other key is kept verbatim as an arena property.
                arena_properties[kk] = vv
        return ArenaConfig(
            arena_id,
            arena_domain,
            arena_trajectories,
            arena_properties,
        )

    def parse_stimulus_config(self) -> None:
        """
        Parse the ``Stimulus`` section into ``self.stimulus_config``.

        The sections are handled as follows:

        - ``Selectivity Type Probabilities`` and ``Peak Rate`` are re-keyed by
          selectivity type id, with float values.
        - ``Arena`` entries become :data:`ArenaConfig` records.
        - Every other key is copied unchanged.
        """
        stimulus_config = {}
        for key, value in self.model_config["Stimulus"].items():
            if key == "Selectivity Type Probabilities":
                stimulus_config[
                    "Selectivity Type Probabilities"
                ] = self._parse_selectivity_map(value)
            elif key == "Peak Rate":
                stimulus_config["Peak Rate"] = self._parse_selectivity_map(value)
            elif key == "Arena":
                stimulus_config["Arena"] = {
                    arena_id: self._parse_arena(arena_id, arena_val)
                    for arena_id, arena_val in value.items()
                }
            else:
                stimulus_config[key] = value
        self.stimulus_config = stimulus_config

    def parse_netclamp_config(self):
        """
        Parse the ``Network Clamp`` section into ``self.netclamp_config``.

        Template parameter rules and weight generators default to empty dicts. An
        optional ``Synaptic Optimization`` entry is stored under the ``"synaptic"`` key
        of the optimization rules.
        """
        netclamp_config_dict = self.model_config["Network Clamp"]
        weight_generator_dict = netclamp_config_dict.get("Weight Generator", {})
        template_params = dict(
            netclamp_config_dict.get("Template Parameter Rules", {})
        )
        opt_param_rules_dict = {}
        if "Synaptic Optimization" in netclamp_config_dict:
            opt_param_rules_dict["synaptic"] = netclamp_config_dict[
                "Synaptic Optimization"
            ]
        self.netclamp_config = NetclampConfig(
            template_params, weight_generator_dict, opt_param_rules_dict
        )

    def parse_origin_coords(self) -> None:
        """
        Replace ``Geometry/Parametric Surface/Origin`` with per-axis callables.

        For each of ``U``, ``V``, ``L`` the spec may be one of two forms:

        - a number: a callable that ignores its argument and returns that number;
        - ``"median"``, ``"mean"``, ``"min"`` or ``"max"``: a callable that
          applies the corresponding numpy reduction to its argument.

        :raises ValueError: for any other spec
        """

        def _constant(value):
            # Bind value per call: a bare `lambda x: spec` inside the loop would
            # capture the loop variable and return the *last* axis' spec.
            return lambda x: value

        reducers = {
            "median": np.median,
            "mean": np.mean,
            "min": np.min,
            "max": np.max,
        }
        origin_spec = self.geometry["Parametric Surface"]["Origin"]
        coords = {}
        for key in ["U", "V", "L"]:
            spec = origin_spec[key]
            if isinstance(spec, (int, float)) and not isinstance(spec, bool):
                coords[key] = _constant(spec)
            elif isinstance(spec, str) and spec in reducers:
                coords[key] = reducers[spec]
            else:
                raise ValueError(
                    f"parse_origin_coords: invalid origin specification for {key}: {spec!r}"
                )
        self.geometry["Parametric Surface"]["Origin"] = coords

    def parse_definitions(self) -> None:
        """Read populations, SWC/synapse/layer types and input selectivity types from ``Definitions``."""
        defs = self.model_config["Definitions"]
        self.Populations = defs["Populations"]
        self.SWC_Types = defs["SWC Types"]
        self.Synapse_Types = defs["Synapse Types"]
        self.layers = defs["Layers"]
        self.selectivity_types = defs["Input Selectivity Types"]

    def parse_globals(self) -> None:
        """Store the ``Global Parameters`` section as ``self.globals``."""
        self.globals = self.model_config["Global Parameters"]

    def parse_syn_mechparams(
        self,
        mechparams_dict: Dict[
            str, Union[Dict[str, Union[int, float]], Dict[str, float]]
        ],
    ) -> Dict[str, Union[Dict[str, Union[int, float]], Dict[str, float]]]:
        """
        Convert per-mechanism parameter dicts, compiling expression-valued entries.

        A parameter whose value is a dict with an ``"expr"`` key becomes an
        :class:`ExprClosure`. It is built from ``[value["parameter"]]``,
        ``value["expr"]``, the optional ``value["const"]`` and ``["x"]``.
        All other values are copied unchanged.

        :raises RuntimeError: if a dict-valued parameter has no ``"expr"`` key
        """
        res = {}
        for mech_name, mech_params in mechparams_dict.items():
            parsed = {}
            for param_name, value in mech_params.items():
                if not isinstance(value, dict):
                    parsed[param_name] = value
                elif "expr" in value:
                    parsed[param_name] = ExprClosure(
                        [value["parameter"]],
                        value["expr"],
                        value.get("const", None),
                        ["x"],
                    )
                else:
                    raise RuntimeError(
                        f"parse_syn_mechparams: unknown parameter type {str(value)}"
                    )
            res[mech_name] = parsed
        return res

    def parse_connection_config(self) -> None:
        """
        Parse the ``Connection Generator`` section.

        The following attributes are set:

        - ``connection_velocity`` and ``synapse_attributes``;
        - ``connection_extents``: per-population axon extents, keyed by layer
          index or ``"default"``;
        - ``connection_config``: ``{postsyn: {presyn: SynapseConfig}}``.

        :raises AssertionError: if, for a postsynaptic population, the proportions for a given
            (synapse type, section, layer) combination do not sum to 1. The check is explicit,
            so it is not stripped under ``python -O``.
        """
        connection_config = self.model_config["Connection Generator"]

        self.connection_velocity = connection_config["Connection Velocity"]
        syn_mech_names = connection_config["Synapse Mechanisms"]
        syn_param_rules = connection_config["Synapse Parameter Rules"]
        self.synapse_attributes = SynapseAttributes(
            self, syn_mech_names, syn_param_rules
        )

        # Axon extents; layer names other than "default" are mapped to layer indices.
        extent_config = connection_config["Axon Extent"]
        self.connection_extents = {}
        for population, pop_extents in extent_config.items():
            pop_connection_extents = {}
            for layer_name, layer_extent in pop_extents.items():
                extent_key = (
                    layer_name
                    if layer_name == "default"
                    else self.layers[layer_name]
                )
                pop_connection_extents[extent_key] = {
                    "width": layer_extent["width"],
                    "offset": layer_extent["offset"],
                }
            self.connection_extents[population] = pop_connection_extents

        synapse_config = connection_config["Synapses"]
        connection_dict = {}
        for key_postsyn, val_syntypes in synapse_config.items():
            connection_dict[key_postsyn] = {}
            for key_presyn, syn_dict in val_syntypes.items():
                res_type = self.Synapse_Types[syn_dict["type"]]
                res_synsections = [
                    self.SWC_Types[name] for name in syn_dict["sections"]
                ]
                res_synlayers = [self.layers[name] for name in syn_dict["layers"]]
                # Mechanism parameters are either shared ("mechanisms", stored under
                # "default") or given per SWC type ("swctype mechanisms").
                if "mechanisms" in syn_dict:
                    res_mechparams = {
                        "default": self.parse_syn_mechparams(syn_dict["mechanisms"])
                    }
                else:
                    res_mechparams = {
                        self.SWC_Types[swc_type]: self.parse_syn_mechparams(
                            swc_mechparams
                        )
                        for swc_type, swc_mechparams in syn_dict[
                            "swctype mechanisms"
                        ].items()
                    }
                connection_dict[key_postsyn][key_presyn] = SynapseConfig(
                    res_type,
                    res_synsections,
                    res_synlayers,
                    syn_dict["proportions"],
                    syn_dict.get("contacts", 1),
                    res_mechparams,
                )

            # Sum proportions per (synapse type, section, layer) across all
            # presynaptic populations; each sum must be 1.
            config_dict = defaultdict(lambda: 0.0)
            for conn_config in connection_dict[key_postsyn].values():
                for s, l, p in zip(
                    conn_config.sections,
                    conn_config.layers,
                    conn_config.proportions,
                ):
                    config_dict[(conn_config.type, s, l)] += p
            for k, v in config_dict.items():
                if not np.isclose(v, 1.0):
                    msg = f"Connection configuration: probabilities for {key_postsyn} do not sum to 1: type: {self.Synapse_Type_index[k[0]]} section: {self.SWC_Type_index[k[1]]} layer {self.layer_type_index[k[2]]} = {v}"
                    self.logger.error(msg)
                    raise AssertionError(msg)

        self.connection_config = connection_dict

    @staticmethod
    def _fit_gapjunction_curve(weights):
        """
        Fit a cubic polynomial to a ``{x: y}`` mapping.

        :return: (polynomial coefficients from ``np.polyfit``, ``[min x, max x]``)
        """
        xs = sorted(weights.keys())
        ys = [weights[x] for x in xs]
        params = np.polyfit(np.asarray(xs), np.asarray(ys), 3)
        bounds = [np.min(xs), np.max(xs)]
        return params, bounds

    def parse_gapjunction_config(self) -> None:
        """
        Parse ``Connection Generator/Gap Junctions`` into ``self.gapjunctions``.

        ``self.gapjunctions`` maps each (population A, population B) pair listed
        under ``Locations`` to a :data:`GapjunctionConfig`. Connection and
        coupling weight curves are fitted with cubic polynomials. It is ``None``
        when no gap junctions are configured.
        """
        connection_config = self.model_config["Connection Generator"]
        if "Gap Junctions" not in connection_config:
            self.gapjunctions = None
            return

        gj_config = connection_config["Gap Junctions"]

        # Section type indices per population pair
        sections = {}
        for pop_a, pop_dict in gj_config["Locations"].items():
            for pop_b, sec_names in pop_dict.items():
                sections[(pop_a, pop_b)] = [
                    self.SWC_Types[sec_name] for sec_name in sec_names
                ]

        connection_probs = {}
        for pop_a, pop_dict in gj_config["Connection Probabilities"].items():
            for pop_b, prob in pop_dict.items():
                connection_probs[(pop_a, pop_b)] = float(prob)

        connection_params, connection_bounds = self._fit_gapjunction_curve(
            gj_config["Connection Weights"]
        )

        coupling_coeffs = {}
        for pop_a, pop_dict in gj_config["Coupling Coefficients"].items():
            for pop_b, coeff in pop_dict.items():
                coupling_coeffs[(pop_a, pop_b)] = float(coeff)

        coupling_params, coupling_bounds = self._fit_gapjunction_curve(
            gj_config["Coupling Weights"]
        )

        self.gapjunctions = {}
        for pair, sec_idxs in sections.items():
            self.gapjunctions[pair] = GapjunctionConfig(
                sec_idxs,
                connection_probs[pair],
                connection_params,
                connection_bounds,
                coupling_coeffs[pair],
                coupling_params,
                coupling_bounds,
            )

    def load_node_rank_map(self, node_rank_file):
        """
        Derive ``self.node_allocation`` from a gid-to-rank assignment file.

        Rank 0 reads the file and broadcasts the resulting ``{gid: rank}`` map. Each line
        holds a gid and a rank separated by whitespace; blank lines are skipped. Each
        rank keeps the gids assigned to it. If any population has none of its gids in
        the file, a warning is logged and ``self.node_allocation`` is set to ``None``.

        :param node_rank_file: path of the gid-to-rank file
        """
        rank = 0
        if self.comm is not None:
            rank = self.comm.Get_rank()

        node_rank_map = None
        if rank == 0:
            dval = {}
            with open(node_rank_file) as fp:
                for line in fp:
                    fields = line.split()
                    if not fields:
                        continue
                    dval[int(fields[0])] = int(fields[1])
            node_rank_map = dval
        node_rank_map = self.comm.bcast(node_rank_map, root=0)

        pop_names = sorted(self.celltypes.keys())

        self.node_allocation = set()
        for pop_name in pop_names:
            present = False
            num = self.celltypes[pop_name]["num"]
            start = self.celltypes[pop_name]["start"]
            for gid in range(start, start + num):
                if gid in node_rank_map:
                    present = True
                    if node_rank_map[gid] == rank:
                        self.node_allocation.add(gid)
            if not present:
                if rank == 0:
                    self.logger.warning(
                        "load_node_rank_map: gids assigned to population %s are not present in node ranks file %s; "
                        "gid to rank assignment will not be used"
                        % (pop_name, node_rank_file)
                    )
                self.node_allocation = None
                break

    def load_celltypes(self) -> None:
        """
        Read population names, gid ranges and cell attribute info from the cell data file.

        Rank 0 performs the neuroh5 reads and broadcasts the results. For every
        population in ``self.celltypes`` that has a range in the data file:

        - sets ``start`` and ``num``;
        - loads ``mech_dict`` from ``mechanism file``. That entry is either a path,
          read on rank 0 and joined with ``config_prefix``, or an inline dict;
        - compiles ``synapses/weights`` expressions into ``closure`` entries.
        """
        rank = self.comm.Get_rank()
        celltypes = self.celltypes
        typenames = sorted(celltypes.keys())

        # comm0 includes only rank 0
        color = 1 if rank == 0 else 0
        comm0 = self.comm.Split(color, 0)

        if rank == 0:
            self.logger.info(f"env.data_file_path = {str(self.data_file_path)}")

        self.cell_attribute_info = None
        population_ranges = None
        population_names = None
        if rank == 0:
            population_names = read_population_names(self.data_file_path, comm0)
            (population_ranges, _) = read_population_ranges(
                self.data_file_path, comm0
            )
            self.cell_attribute_info = read_cell_attribute_info(
                self.data_file_path, population_names, comm=comm0
            )
            self.logger.info(f"population_names = {str(population_names)}")
            self.logger.info(f"population_ranges = {str(population_ranges)}")
            self.logger.info(f"attribute info: {str(self.cell_attribute_info)}")
        population_ranges = self.comm.bcast(population_ranges, root=0)
        population_names = self.comm.bcast(population_names, root=0)
        self.cell_attribute_info = self.comm.bcast(
            self.cell_attribute_info, root=0
        )
        comm0.Free()

        for k in typenames:
            population_range = population_ranges.get(k, None)
            if population_range is None:
                continue
            celltypes[k]["start"] = population_range[0]
            celltypes[k]["num"] = population_range[1]

            if "mechanism file" in celltypes[k]:
                mech_config = celltypes[k]["mechanism file"]
                if isinstance(mech_config, str):
                    celltypes[k]["mech_file_path"] = mech_config
                    mech_dict = None
                    if rank == 0:
                        mech_file_path = mech_config
                        if self.config_prefix is not None:
                            mech_file_path = os.path.join(
                                self.config_prefix, mech_file_path
                            )
                        mech_dict = read_from_yaml(mech_file_path)
                else:
                    mech_dict = mech_config
                # Every rank executes this bcast, whichever branch it took.
                mech_dict = self.comm.bcast(mech_dict, root=0)
                celltypes[k]["mech_dict"] = mech_dict

            if "synapses" in celltypes[k]:
                synapses_dict = celltypes[k]["synapses"]
                if "weights" in synapses_dict:
                    weights_config = synapses_dict["weights"]
                    # Normalize to a list of weight rule dicts.
                    if isinstance(weights_config, list):
                        weights_dicts = weights_config
                    else:
                        weights_dicts = [weights_config]
                    for weights_dict in weights_dicts:
                        if "expr" in weights_dict:
                            weights_dict["closure"] = ExprClosure(
                                weights_dict["parameter"],
                                weights_dict["expr"],
                                weights_dict.get("const", {}),
                            )
                    synapses_dict["weights"] = weights_dicts

    def clear(self):
        """
        Reset per-simulation network state.

        Drops cells, gap junctions and recordings, and empties the ``t_vec``,
        ``id_vec`` and ``t_rec`` vectors when present. Calls ``self.pc.gid_clear()``
        when a parallel context object is set (presumably a NEURON ParallelContext).
        """
        self.gidset = set()
        self.gjlist = []
        self.cells = defaultdict(dict)
        self.artificial_cells = defaultdict(dict)
        self.biophys_cells = defaultdict(dict)
        self.recording_sets = {}
        if self.pc is not None:
            self.pc.gid_clear()
        for vec in (self.t_vec, self.id_vec, self.t_rec):
            if vec is not None:
                vec.resize(0)
        self.recs_dict = {
            pop_name: defaultdict(list) for pop_name in self.Populations
        }
        self.recs_count = 0
        self.recs_pps_set = set()