# Module docstring for the network clamp routines (assigned explicitly to
# __doc__ rather than written as a bare string literal).
__doc__ = """
Routines for Network Clamp simulation.
"""
import gc
import os
import pprint
import sys
import time
from collections import defaultdict
import click
import h5py
import numpy as np
from miv_simulator import spikedata, stimulus, synapses
from miv_simulator.opto.run import OptoStim
from miv_simulator.clamps.cell import init_biophys_cell
from miv_simulator.cells import (
h,
is_cell_registered,
load_biophys_cell_dicts,
make_input_cell,
record_cell,
register_cell,
report_topology,
)
from miv_simulator.env import Env
from miv_simulator.utils.neuron import configure_hoc_env, h
from miv_simulator.stimulus import (
oscillation_phase_mod_config,
rate_maps_from_features,
)
from miv_simulator.utils import (
Struct,
Context,
config_logging,
generate_results_file_id,
get_low_pass_filtered_trace,
get_module_logger,
get_trial_time_indices,
is_interactive,
list_find,
list_index,
read_from_yaml,
write_to_yaml,
)
from miv_simulator.utils import io as io_utils
from mpi4py import MPI
from neuroh5.io import (
bcast_cell_attributes,
read_cell_attribute_info,
scatter_read_cell_attribute_selection,
)
# This logger will inherit its settings from the root logger, created in miv_simulator.env
logger = get_module_logger(__name__)
# Shared interactive context: init() stores its local variables here when
# running in an interactive session.
context = Context()
# NOTE(review): module-level placeholder. The routines in this file receive
# `env` as an argument or create their own Env instance; confirm whether any
# code outside this chunk still relies on this global.
env = None
def set_union(s, t, datatype):
    """
    Reduction callback for ``MPI.Op.Create``.

    Combines two partial results by returning ``s.union(t)``. The
    ``datatype`` argument is required by the MPI callback signature and is
    ignored.
    """
    merged = s.union(t)
    return merged
# Commutative MPI reduction operator that unions Python sets; init() uses it
# with comm.reduce to merge the per-rank presynaptic gid sets on rank 0.
mpi_op_set_union = MPI.Op.Create(set_union, commute=True)
def mpi_excepthook(type, value, traceback):
    """
    ``sys.excepthook`` replacement for MPI runs.

    Delegates to the previously installed hook (kept in the module-level
    ``sys_excepthook``), flushes stdout and stderr, and, when more than one
    MPI process is running, calls ``MPI.COMM_WORLD.Abort(1)`` so that an
    uncaught exception terminates the whole job instead of a single rank.

    :param type: exception class
    :param value: exception instance
    :param traceback: traceback object
    :return: None
    """
    sys_excepthook(type, value, traceback)
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    world = MPI.COMM_WORLD
    if world.size > 1:
        world.Abort(1)
# Keep a reference to the current hook before replacing it, so that
# mpi_excepthook can still delegate to it. The order of these two
# statements matters.
sys_excepthook = sys.excepthook
sys.excepthook = mpi_excepthook
def generate_weights(env, weight_source_rules, this_syn_attrs):
    """
    Generate synaptic weights according to the rules in the Weight
    Generator section of the network clamp configuration.

    :param env: an instance of env.Env; ``env.model_config["Random Seeds"]``
        supplies the seed offset for each rule class
    :param weight_source_rules: dict mapping a presynaptic population index
        to a rule dict with keys ``"class"`` (``"Sparse"``, ``"Log-Normal"``
        or ``"Normal"``), ``"name"`` and ``"params"``
    :param this_syn_attrs: dict mapping synapse id to a synapse attribute
        object exposing ``source.population`` and ``source.gid``
    :return: dict mapping each presynaptic population index to the value
        returned by the matching ``synapses.generate_*_weights`` call
    :raises RuntimeError: if a rule has an unknown ``"class"``
    """
    weights_dict = {}
    for presyn_id, rule in weight_source_rules.items():
        # Group this cell's synapse ids by presynaptic gid, restricted to the
        # population the current rule applies to.
        syn_ids_by_source = defaultdict(list)
        for syn_id, syn in this_syn_attrs.items():
            src_pop, src_gid = syn.source.population, syn.source.gid
            if src_pop == presyn_id:
                syn_ids_by_source[src_gid].append(syn_id)

        rule_class = rule["class"]
        if rule_class == "Sparse":
            name, params = rule["name"], rule["params"]
            fraction = params["fraction"]
            seed = int(env.model_config["Random Seeds"]["Sparse Weights"]) + 1
            weights_dict[presyn_id] = synapses.generate_sparse_weights(
                name, fraction, seed, syn_ids_by_source
            )
        elif rule_class == "Log-Normal":
            name, params = rule["name"], rule["params"]
            mu, sigma = params["mu"], params["sigma"]
            clip = params["clip"] if "clip" in params else None
            seed = (
                int(env.model_config["Random Seeds"]["GC Log-Normal Weights 1"]) + 1
            )
            weights_dict[presyn_id] = synapses.generate_log_normal_weights(
                name, mu, sigma, seed, syn_ids_by_source, clip=clip
            )
        elif rule_class == "Normal":
            name, params = rule["name"], rule["params"]
            mu, sigma = params["mu"], params["sigma"]
            seed = int(env.model_config["Random Seeds"]["GC Normal Weights"]) + 1
            weights_dict[presyn_id] = synapses.generate_normal_weights(
                name, mu, sigma, seed, syn_ids_by_source
            )
        else:
            raise RuntimeError(
                "network_clamp.generate_weights: unknown weight generator rule class "
                f"{rule_class}"
            )
    return weights_dict
def _init_input_source_dict(
    env,
    presyn_sources,
    t_range,
    *,
    spike_events_path,
    spike_events_namespace,
    spike_train_attr_name,
    input_features_path,
    input_features_namespaces,
    arena_id,
    stimulus_id,
    n_trials,
    input_seed,
    phase_mod,
    soma_positions_dict,
):
    """Load input spikes from ``spike_events_path`` or, failing that, generate them from ``input_features_path``."""
    if spike_events_path is not None:
        return init_inputs_from_spikes(
            env,
            presyn_sources,
            t_range,
            spike_events_path,
            spike_events_namespace,
            arena_id,
            stimulus_id,
            spike_train_attr_name,
            n_trials,
        )
    if input_features_path is not None:
        return init_inputs_from_features(
            env,
            presyn_sources,
            t_range,
            input_features_path,
            input_features_namespaces,
            arena_id=arena_id,
            stimulus_id=stimulus_id,
            spike_train_attr_name=spike_train_attr_name,
            n_trials=n_trials,
            seed=input_seed,
            phase_mod=phase_mod,
            soma_positions_dict=soma_positions_dict,
        )
    raise RuntimeError(
        "network_clamp.init: neither input spikes nor input features are provided"
    )


def _init_soma_positions(env, coords_path, populations, distances_namespace):
    """Return ``{population: {gid: U Distance - Reference U Min}}`` read from ``coords_path``."""
    soma_positions_dict = {}
    for population in populations:
        # Rank 0 reads the reference bounds and broadcasts them to all ranks.
        reference_u_arc_distance_bounds = None
        if env.comm.rank == 0:
            with h5py.File(coords_path, "r") as coords_f:
                distance_attrs = coords_f["Populations"][population][
                    distances_namespace
                ].attrs
                reference_u_arc_distance_bounds = (
                    distance_attrs["Reference U Min"],
                    distance_attrs["Reference U Max"],
                )
        env.comm.barrier()
        reference_u_arc_distance_bounds = env.comm.bcast(
            reference_u_arc_distance_bounds, root=0
        )
        distances = bcast_cell_attributes(
            coords_path, population, namespace=distances_namespace, root=0
        )
        u_min = reference_u_arc_distance_bounds[0]
        soma_positions_dict[population] = {
            k: v["U Distance"][0] - u_min for (k, v) in distances
        }
        del distances
    return soma_positions_dict


def init(
    env,
    pop_name,
    cell_index_set,
    arena_id=None,
    stimulus_id=None,
    n_trials=1,
    spike_events_path=None,
    spike_events_namespace="Spike Events",
    spike_train_attr_name="t",
    input_features_path=None,
    input_features_namespaces=None,
    coords_path=None,
    distances_namespace="Arc Distances",
    phase_mod=False,
    generate_weights_pops=None,
    t_min=None,
    t_max=None,
    write_cell=False,
    plot_cell=False,
    input_seed=None,
    cooperative_init=False,
    worker=None,
):
    """
    Instantiate cells with all their synapses and connections, and load or
    generate spike times for all presynaptic inputs.

    :param env: an instance of env.Env
    :param pop_name: population name
    :param cell_index_set: gids of the cells to instantiate; they are
        distributed round-robin over the ranks of ``env.comm``
    :param arena_id: arena id passed on to the input loaders
    :param stimulus_id: stimulus id passed on to the input loaders
    :param n_trials: number of trials passed on to the input loaders
    :param spike_events_path: file with recorded input spike events; takes
        precedence over ``input_features_path``
    :param spike_events_namespace: attribute namespace of the spike events
    :param spike_train_attr_name: attribute name of the spike times
    :param input_features_path: file with input features, used to generate
        input spikes when ``spike_events_path`` is not given
    :param input_features_namespaces: namespaces read from ``input_features_path``
    :param coords_path: coordinates file; required when ``phase_mod`` is True
    :param distances_namespace: namespace of the arc distances in ``coords_path``
    :param phase_mod: passed on to ``init_inputs_from_features``
    :param generate_weights_pops: presynaptic populations that must have a
        weight generator rule in ``env.netclamp_config.weight_generators``
        (defaults to no populations)
    :param t_min: start of the input time range (0.0 if None); only used
        together with ``t_max``
    :param t_max: if given, only input spikes in ``[t_min, t_max]`` are
        loaded or generated, and ``env.tstart``/``env.tstop`` are set to it
    :param write_cell: passed on to ``init_biophys_cell``
    :param plot_cell: plot AMPA synaptic parameter distributions per cell
    :param input_seed: seed passed on to ``init_inputs_from_features``
    :param cooperative_init: together with ``worker``, worker 1 loads cell
        and input data and sends it to rank 0 of ``worker.merged_comm``;
        other workers receive from ``worker.merged_comm``
    :param worker: worker object used for cooperative initialization
    :return: set of gids instantiated on this rank
    :raises RuntimeError: on inconsistent arguments or missing configuration
    """
    if phase_mod and coords_path is None:
        raise RuntimeError(
            "network_clamp.init: when phase_mod is True, coords_path must be provided"
        )
    if generate_weights_pops is None:
        generate_weights_pops = set()

    if env.cell_selection is None:
        env.cell_selection = {}
    # Extend (rather than nest) any existing selection for this population,
    # dropping duplicate gids while preserving order.
    selection = env.cell_selection.get(pop_name, [])
    env.cell_selection[pop_name] = list(
        dict.fromkeys(list(cell_index_set) + list(selection))
    )

    ## If specified, presynaptic spikes that only fall within this time range
    ## will be loaded or generated
    if t_max is None:
        t_range = None
    else:
        t_range = [0.0 if t_min is None else t_min, t_max]

    # Round-robin assignment of the requested cells to ranks.
    my_cell_index_set = {
        gid
        for i, gid in enumerate(cell_index_set)
        if i % env.comm.size == env.comm.rank
    }

    if (worker is not None) and cooperative_init:
        if worker.worker_id == 1:
            cell_dict = load_biophys_cell_dicts(env, pop_name, my_cell_index_set)
            req = worker.merged_comm.isend(
                cell_dict, tag=InitMessageTag["cell"].value, dest=0
            )
            req.wait()
        else:
            cell_dict = worker.merged_comm.recv(
                source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG
            )
    else:
        cell_dict = load_biophys_cell_dicts(env, pop_name, my_cell_index_set)

    ## Load cell gid and its synaptic attributes and connection data
    for gid in my_cell_index_set:
        init_biophys_cell(
            env, pop_name, gid, cell_dict=cell_dict[gid], write_cell=write_cell
        )
        # The per-cell data is no longer needed once the cell is built.
        del cell_dict[gid]

    pop_index_dict = {ind: name for name, ind in env.Populations.items()}

    ## Determine presynaptic populations that connect to this cell type
    presyn_names = sorted(env.projection_dict[pop_name])
    all_populations = [pop_name] + presyn_names

    # Check that every population listed in generate_weights_pops has a weight
    # generator rule. NOTE(review): weight_source_dict is not used further in
    # init; it is only exposed through the interactive context.
    weight_source_dict = {}
    for presyn_name in presyn_names:
        presyn_index = int(env.Populations[presyn_name])
        if presyn_name not in generate_weights_pops:
            continue
        pop_weight_generators = env.netclamp_config.weight_generators[pop_name]
        if presyn_name not in pop_weight_generators:
            raise RuntimeError(
                f"network_clamp.init: no weights generator rule specified for population {presyn_name}"
            )
        weight_rule = pop_weight_generators[presyn_name]
        if weight_rule is not None:
            weight_source_dict[presyn_index] = weight_rule

    ## Collect the presynaptic gids of all local synapses, per population
    syn_attrs = env.synapse_attributes
    presyn_sources = {presyn_name: set() for presyn_name in presyn_names}
    for gid in my_cell_index_set:
        for syn_id, syn in syn_attrs[gid].items():
            presyn_id = syn.source.population
            if presyn_id is None:
                raise RuntimeError(
                    f"gid {gid} synapse {syn_id} presyn_id is None: syn = {syn}"
                )
            presyn_sources[pop_index_dict[presyn_id]].add(syn.source.gid)

    ## Merge the presynaptic gids of all ranks on rank 0, then scatter them
    ## back round-robin so that each presynaptic gid is handled by one rank.
    for presyn_name in presyn_names:
        presyn_gid_set = env.comm.reduce(
            presyn_sources[presyn_name], root=0, op=mpi_op_set_union
        )
        env.comm.barrier()
        if env.comm.rank == 0:
            gids_per_rank = [set() for _ in range(env.comm.size)]
            for i, presyn_gid in enumerate(presyn_gid_set):
                gids_per_rank[i % env.comm.size].add(presyn_gid)
            presyn_sources[presyn_name] = env.comm.scatter(gids_per_rank, root=0)
        else:
            presyn_sources[presyn_name] = env.comm.scatter(None, root=0)
        env.comm.barrier()

    soma_positions_dict = None
    if coords_path is not None:
        soma_positions_dict = _init_soma_positions(
            env, coords_path, all_populations, distances_namespace
        )

    if env.opsin_config is not None:
        opsin_pop_dict = {
            name: set(env.cells[name].keys()).difference(
                set(env.artificial_cells[name].keys())
            )
            for name in env.cells.keys()
        }
        rho_params = env.opsin_config["rho parameters"]
        protocol_params = env.opsin_config["protocol parameters"]
        env.opto_stim = OptoStim(
            env.pc,
            opsin_pop_dict,
            model_nstates=env.opsin_config["nstates"],
            opsin_type=env.opsin_config["opsin type"],
            protocol=env.opsin_config["protocol"],
            protocol_params=protocol_params,
            rho_params=rho_params,
            seed=int(env.model_config["Random Seeds"].get("Opsin", None)),
        )
        # Use the communicator rank: there is no `rank` variable in scope here.
        if env.comm.rank == 0:
            logger.info("*** Opsin configuration instantiated")

    def load_input_source_dict():
        return _init_input_source_dict(
            env,
            presyn_sources,
            t_range,
            spike_events_path=spike_events_path,
            spike_events_namespace=spike_events_namespace,
            spike_train_attr_name=spike_train_attr_name,
            input_features_path=input_features_path,
            input_features_namespaces=input_features_namespaces,
            arena_id=arena_id,
            stimulus_id=stimulus_id,
            n_trials=n_trials,
            input_seed=input_seed,
            phase_mod=phase_mod,
            soma_positions_dict=soma_positions_dict,
        )

    if (worker is not None) and cooperative_init:
        if worker.worker_id == 1:
            input_source_dict = load_input_source_dict()
            req = worker.merged_comm.isend(
                input_source_dict, tag=InitMessageTag["input"].value, dest=0
            )
            req.wait()
        else:
            input_source_dict = worker.merged_comm.recv(
                source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG
            )
    else:
        input_source_dict = load_input_source_dict()

    if t_range is not None:
        env.tstart = t_range[0]
        env.tstop = t_range[1]
    env.comm.barrier()

    ## Create an input cell (make_input_cell), carrying the loaded spike
    ## times, for every presynaptic gid that is neither part of env.gidset nor
    ## already registered.
    for presyn_name in presyn_names:
        presyn_id = int(env.Populations[presyn_name])
        if presyn_id not in input_source_dict:
            continue
        for presyn_gid in presyn_sources[presyn_name]:
            if (presyn_gid in env.gidset) or is_cell_registered(env, presyn_gid):
                continue
            cell = make_input_cell(
                env,
                presyn_gid,
                presyn_id,
                input_source_dict,
                spike_train_attr_name=spike_train_attr_name,
            )
            register_cell(env, presyn_name, presyn_gid, cell)

    for gid in my_cell_index_set:
        synapses.config_biophys_cell_syns(
            env, gid, pop_name, insert=True, insert_netcons=True, verbose=True
        )
        record_cell(env, pop_name, gid)
    gc.collect()

    if plot_cell:
        from miv_simulator.plotting import plot_synaptic_attribute_distribution

        syn_name = "AMPA"
        syn_mech_name = syn_attrs.syn_mech_names[syn_name]
        for gid in my_cell_index_set:
            biophys_cell = env.biophys_cells[pop_name][gid]
            for param_name in ["weight", "g_unit"]:
                param_label = f"{syn_name}; {syn_mech_name}; {param_name};"
                plot_synaptic_attribute_distribution(
                    biophys_cell,
                    env,
                    syn_name,
                    param_name,
                    filters=None,
                    from_mech_attrs=True,
                    from_target_attrs=True,
                    param_label=param_label,
                    export=f"syn_params_{gid}.h5",
                    description="network_clamp",
                    show=False,
                    svg_title=f"Synaptic parameters for gid {gid}",
                    output_dir=env.results_path,
                )

    if env.verbose:
        # Print the sections of the first locally registered cell only.
        for gid in my_cell_index_set:
            if is_cell_registered(env, gid):
                cell = env.pc.gid2cell(gid)
                for sec in list(
                    cell.hoc_cell.all if hasattr(cell, "hoc_cell") else cell.all
                ):
                    h.psection(sec=sec)
                break

    env.pc.set_maxstep(10)
    if is_interactive:
        context.update(locals())
    env.comm.barrier()
    return my_cell_index_set
def run(env, cvode=False, pc_runworker=False):
    """
    Run a network clamp simulation.

    ``init`` must already have been called to build the cells and inputs
    described by ``env``.

    :param env: instance of env.Env
    :param cvode: use adaptive (CVode) integration when True
    :param pc_runworker: when True, call ``env.pc.runworker()`` after the
        run, before ``env.pc.done()``
    :return: spike dictionary from ``spikedata.get_env_spike_dict``
    """
    pc = env.pc
    rank = int(pc.id())
    nhosts = int(pc.nhost())
    is_root = rank == 0

    profile = env.recording_profile
    rec_dt = profile.get("dt", None) if profile is not None else None

    if env.recs_count == 0:
        ## placeholder compartment to allow recording of time below
        h("""create soma""")
    if rec_dt is not None:
        env.t_rec.record(h._ref_t, rec_dt)
    else:
        env.t_rec.record(h._ref_t)
    env.t_vec.resize(0)
    env.id_vec.resize(0)

    st_comptime = pc.step_time()

    h.cvode_active(1 if cvode else 0)
    h.t = float(env.tstart)
    h.dt = env.dt
    # Each trial lasts tstop plus the optional equilibration period.
    trial_duration = float(env.tstop)
    if "Equilibration Duration" in env.stimulus_config:
        trial_duration += float(env.stimulus_config["Equilibration Duration"])
    h.tstop = float(env.n_trials) * trial_duration
    h.finitialize(env.v_init)

    if is_root:
        logger.info(
            f"*** Running simulation with dt = {h.dt:.04f} and tstop = {h.tstop:.02f}"
        )
    pc.barrier()
    pc.psolve(h.tstop)
    if is_root:
        logger.info("*** Simulation completed")
    pc.barrier()

    comptime = pc.step_time() - st_comptime
    # Collective reductions: every rank must take part, even though the
    # results are not reported.
    pc.allreduce(comptime, 1) / nhosts
    pc.allreduce(comptime, 2)
    if is_root:
        logger.info(f"Host {rank} ran simulation in {comptime:.02f} seconds")

    if pc_runworker:
        pc.runworker()
    pc.done()

    return spikedata.get_env_spike_dict(env, include_artificial=None)
def update_params(env, pop_param_dict):
    """
    Apply synaptic parameter values to the locally instantiated cells.

    :param env: an instance of env.Env
    :param pop_param_dict: ``{population: {gid: [(param_tuple, value), ...]}}``
        where each ``param_tuple`` exposes ``population``, ``source``,
        ``sec_type``, ``syn_name`` and ``param_path``. ``source`` and
        ``sec_type`` may be scalars or lists/tuples; ``param_path`` is either
        a parameter name or a ``(name, sub_key)`` pair. Gids that are not in
        ``env.biophys_cells[population]`` are skipped.
    :raises RuntimeError: if a parameter tuple belongs to a different
        population than the one it is listed under
    """
    for population, param_tuple_dict in pop_param_dict.items():
        biophys_cell_dict = env.biophys_cells[population]
        for gid, param_tuples in param_tuple_dict.items():
            if gid not in biophys_cell_dict:
                # Cell is not instantiated on this rank.
                continue
            biophys_cell = biophys_cell_dict[gid]
            is_reduced = getattr(biophys_cell, "is_reduced", False)
            for param_tuple, param_value in param_tuples:
                # Explicit check instead of `assert`, which is stripped under -O.
                if param_tuple.population != population:
                    raise RuntimeError(
                        f"network_clamp.update_params: parameter for population "
                        f"{param_tuple.population} listed under population {population}"
                    )
                param_path = param_tuple.param_path
                if isinstance(param_path, (list, tuple)):
                    p, s = param_path
                else:
                    p, s = param_path, None

                source = param_tuple.source
                if isinstance(source, (list, tuple)):
                    sources = source
                elif source is not None:
                    sources = [source]
                else:
                    sources = None

                sec_type = param_tuple.sec_type
                if isinstance(sec_type, (list, tuple)):
                    sec_types = sec_type
                else:
                    sec_types = [sec_type]

                for this_sec_type in sec_types:
                    synapses.modify_syn_param(
                        biophys_cell,
                        env,
                        this_sec_type,
                        param_tuple.syn_name,
                        param_name=p,
                        value={s: param_value} if (s is not None) else param_value,
                        filters={"sources": sources} if sources is not None else None,
                        origin=None if is_reduced else "soma",
                        update_targets=True,
                    )
def run_with(env, param_dict, cvode=False, pc_runworker=False):
    """
    Run a network clamp simulation with temporarily modified synaptic
    parameters.

    Assumes that ``init`` has been called with the network configuration
    provided by ``env``. The mechanism attributes of every listed cell are
    stashed before the parameters are applied, and they are restored
    afterwards, even if the update or the simulation raises.

    :param env: instance of env.Env
    :param param_dict: ``{population: {gid: [(param_tuple, value), ...]}}``,
        as accepted by ``update_params``
    :param cvode: use adaptive (CVode) integration when True
    :param pc_runworker: when True, call ``env.pc.runworker()`` after the run
    :return: spike dictionary from ``spikedata.get_env_spike_dict``
    """
    rank = int(env.pc.id())
    nhosts = int(env.pc.nhost())
    syn_attrs = env.synapse_attributes
    stash_id_dict = defaultdict(dict)
    try:
        for pop_name in param_dict:
            for gid in param_dict[pop_name]:
                stash_id_dict[pop_name][gid] = syn_attrs.stash_mech_attrs(
                    pop_name, gid
                )
        update_params(env, param_dict)

        rec_dt = None
        if env.recording_profile is not None:
            rec_dt = env.recording_profile.get("dt", None)
        if env.recs_count == 0:
            ## placeholder compartment to allow recording of time below
            h("""create soma""")
        if rec_dt is None:
            env.t_rec.record(h._ref_t)
        else:
            env.t_rec.record(h._ref_t, rec_dt)
        env.t_vec.resize(0)
        env.id_vec.resize(0)
        st_comptime = env.pc.step_time()
        h.cvode_active(1 if cvode else 0)
        h.t = float(env.tstart)
        h.dt = env.dt
        tstop = float(env.tstop)
        if "Equilibration Duration" in env.stimulus_config:
            tstop += float(env.stimulus_config["Equilibration Duration"])
        h.tstop = float(env.n_trials) * tstop
        # NOTE(review): finitialize is called twice here but once in run();
        # kept as in the original since the reason is not evident.
        h.finitialize(env.v_init)
        h.finitialize(env.v_init)
        if rank == 0:
            logger.info(
                f"*** Running simulation with dt = {h.dt:.04f} and tstop = {h.tstop:.02f}"
            )
            logger.info(f"*** Parameters: {pprint.pformat(param_dict)}")
        env.pc.barrier()
        env.pc.psolve(h.tstop)
        if rank == 0:
            logger.info("*** Simulation completed")
        env.pc.barrier()
        comptime = env.pc.step_time() - st_comptime
        # Collective reductions: every rank must take part.
        env.pc.allreduce(comptime, 1) / nhosts
        env.pc.allreduce(comptime, 2)
        if rank == 0:
            logger.info(f"Host {rank} ran simulation in {comptime:.02f} seconds")
        if pc_runworker:
            env.pc.runworker()
        env.pc.done()
    finally:
        # Restore the original synaptic parameters of every stashed cell.
        for pop_name, gid_stash_ids in stash_id_dict.items():
            for gid, stash_id in gid_stash_ids.items():
                syn_attrs.restore_mech_attrs(pop_name, gid, stash_id)
                synapses.config_biophys_cell_syns(env, gid, pop_name, insert=False)
    return spikedata.get_env_spike_dict(env, include_artificial=None)
def init_state_objfun(
    config_file,
    population,
    cell_index_set,
    arena_id,
    stimulus_id,
    generate_weights,
    t_max,
    t_min,
    opt_iter,
    template_paths,
    dataset_prefix,
    results_path,
    spike_events_path,
    spike_events_namespace,
    spike_events_t,
    input_features_path,
    input_features_namespaces,
    coords_path,
    distances_namespace,
    phase_mod,
    n_trials,
    trial_regime,
    problem_regime,
    param_type,
    param_config_name,
    recording_profile,
    state_variable,
    state_filter,
    target_value,
    use_coreneuron,
    cooperative_init,
    dt,
    worker,
    **kwargs,
):
    """
    Build the network clamp model and return an objective for fitting a
    soma state variable to a target value.

    All arguments are also forwarded to ``Env`` as keyword arguments (with
    ``config_file`` renamed to ``config``).

    :param state_variable: soma section quantity to record
    :param state_filter: ``"lowpass"`` applies ``get_low_pass_filtered_trace``
        to each trial before averaging; any other value averages the raw trace
    :param target_value: target value of the per-trial mean state
    :param trial_regime: ``"mean"`` scores ``-|mean over trials - target|``;
        ``"best"`` scores the trial closest to the target
    :return: ``opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)``
    """
    # Must run first: captures only the function arguments.
    params = dict(locals())
    params["config"] = params.pop("config_file")
    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)
    my_cell_index_set = init(
        env,
        population,
        cell_index_set,
        arena_id,
        stimulus_id,
        n_trials,
        spike_events_path,
        spike_events_namespace=spike_events_namespace,
        spike_train_attr_name=spike_events_t,
        coords_path=coords_path,
        distances_namespace=distances_namespace,
        phase_mod=phase_mod,
        input_features_path=input_features_path,
        input_features_namespaces=input_features_namespaces,
        generate_weights_pops=set(generate_weights),
        t_min=t_min,
        t_max=t_max,
        cooperative_init=cooperative_init,
        worker=worker,
    )
    # Optional, as in run_with and init_rate_objfun.
    equilibration_duration = float(
        env.stimulus_config.get("Equilibration Duration", 0.0)
    )
    opt_param_config = optimization_params(
        env.netclamp_config.optimize_parameters,
        [population],
        param_config_name,
        param_type,
    )
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    recording_profile = {
        "label": f"network_clamp.state.{state_variable}",
        "dt": None if use_coreneuron else 0.1,
        "section quantity": {state_variable: {"swc types": ["soma"]}},
    }
    env.recording_profile = recording_profile
    state_recs_dict = {}
    for gid in my_cell_index_set:
        state_recs_dict[gid] = record_cell(
            env, population, gid, recording_profile=recording_profile
        )

    filter_fun = get_low_pass_filtered_trace if state_filter == "lowpass" else None

    def from_param_dict(params_dict):
        return [
            (param_tuple, params_dict[param_pattern])
            for param_pattern, param_tuple in zip(param_names, param_tuples)
        ]

    def gid_state_values(t_offset, n_trials, t_rec, state_recs_dict):
        """Return ``{gid: [per-trial means of each recording]}``."""
        t_vec = np.asarray(t_rec.to_python(), dtype=np.float32)
        t_trial_inds = get_trial_time_indices(t_vec, n_trials, t_offset)
        results_dict = {}
        for gid, state_recs in state_recs_dict.items():
            state_values = []
            for rec in state_recs:
                vec = np.asarray(rec["vec"].to_python(), dtype=np.float32)
                if filter_fun is None:
                    trial_means = [np.mean(vec[t_inds]) for t_inds in t_trial_inds]
                else:
                    trial_means = [
                        np.mean(filter_fun(vec[t_inds], t_vec[t_inds]))
                        for t_inds in t_trial_inds
                    ]
                # Keep per-trial values so the "best" regime can pick a trial.
                state_values.append(np.asarray(trial_means))
            results_dict[gid] = state_values
        return results_dict

    def eval_problem(cell_param_dict, **kwargs):
        run_with(
            env,
            {
                population: {
                    gid: from_param_dict(cell_param_dict[gid])
                    for gid in my_cell_index_set
                }
            },
        )
        state_values_dict = gid_state_values(
            equilibration_duration, n_trials, env.t_rec, state_recs_dict
        )
        if trial_regime == "mean":
            return {
                gid: -abs(np.mean(state_values_dict[gid]) - target_value)
                for gid in my_cell_index_set
            }
        elif trial_regime == "best":
            return {
                gid: -np.min(
                    np.abs(np.asarray(state_values_dict[gid]) - target_value)
                )
                for gid in my_cell_index_set
            }
        else:
            raise RuntimeError(f"state_objfun: unknown trial regime {trial_regime}")

    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
def init_rate_objfun(
    config_file,
    population,
    cell_index_set,
    arena_id,
    stimulus_id,
    n_trials,
    trial_regime,
    problem_regime,
    generate_weights,
    t_max,
    t_min,
    opt_iter,
    template_paths,
    dataset_prefix,
    results_path,
    spike_events_path,
    spike_events_namespace,
    spike_events_t,
    coords_path,
    distances_namespace,
    phase_mod,
    input_features_path,
    input_features_namespaces,
    param_type,
    param_config_name,
    recording_profile,
    target_rate,
    use_coreneuron,
    cooperative_init,
    dt,
    worker,
    **kwargs,
):
    """
    Build the network clamp model and return an objective for fitting the
    firing rate of each cell to ``target_rate``.

    All arguments are also forwarded to ``Env`` as keyword arguments (with
    ``config_file`` renamed to ``config``). The soma voltage is recorded; a
    cell whose mean (threshold-clipped) voltage reaches
    ``threshold - margin`` from the ``"<population> state"`` optimization
    target gets a ``-1e6`` objective penalty and a violated constraint.

    :param trial_regime: ``"mean"`` compares the mean non-zero trial rate
        with the target; ``"best"`` compares the maximum non-zero trial rate
    :return: ``opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)``
    :raises RuntimeError: if no voltage threshold target is configured
    """
    # Must run first: captures only the function arguments.
    params = dict(locals())
    params["config"] = params.pop("config_file")
    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)
    my_cell_index_set = init(
        env,
        population,
        cell_index_set,
        arena_id,
        stimulus_id,
        n_trials,
        spike_events_path=spike_events_path,
        spike_events_namespace=spike_events_namespace,
        spike_train_attr_name=spike_events_t,
        coords_path=coords_path,
        distances_namespace=distances_namespace,
        phase_mod=phase_mod,
        input_features_path=input_features_path,
        input_features_namespaces=input_features_namespaces,
        generate_weights_pops=set(generate_weights),
        t_min=t_min,
        t_max=t_max,
        cooperative_init=cooperative_init,
        worker=worker,
    )
    # Interval used to convert spike counts into rates. Without t_max, init()
    # leaves env.tstart/env.tstop unchanged, so fall back to those instead of
    # failing on None. The /1e3 presumably converts ms to s.
    if t_max is not None:
        time_range = (t_min if t_min is not None else 0.0, t_max)
    else:
        time_range = (float(env.tstart), float(env.tstop))
    tsecs = (time_range[1] - time_range[0]) / 1e3
    equilibration_duration = float(
        env.stimulus_config.get("Equilibration Duration", 0.0)
    )
    opt_param_config = optimization_params(
        env.netclamp_config.optimize_parameters,
        [population],
        param_config_name,
        param_type,
    )
    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    recording_profile = {
        "label": f"network_clamp.rate.v",
        "dt": None if use_coreneuron else 0.1,
        "section quantity": {"v": {"swc types": ["soma"]}},
    }
    env.recording_profile = recording_profile
    state_recs_dict = {}
    for gid in my_cell_index_set:
        state_recs_dict[gid] = record_cell(
            env, population, gid, recording_profile=recording_profile
        )

    v_targets = opt_targets[f"{population} state"]["v"]
    target_v_threshold = v_targets.get("threshold", None)
    target_v_margin = v_targets.get("margin", -1.0)
    if target_v_threshold is None:
        raise RuntimeError(
            f"network_clamp: network clamp optimization configuration for population {population} "
            f"must have state variable v threshold setting in section Targets"
        )

    N_objectives = 1
    opt_rate_feature_dtypes = [
        ("mean_rate", (np.float32, (1,))),
        ("trial_objs", (np.float32, (N_objectives, n_trials))),
        ("mean_v", (np.float32, (n_trials,))),
    ]

    def from_param_dict(params_dict):
        return [
            (param_tuple, params_dict[param_pattern])
            for param_pattern, param_tuple in zip(param_names, param_tuples)
        ]

    def nonzero_rates(rates):
        """Return the entries of ``rates`` that are not (close to) zero."""
        rates_array = np.asarray(rates)
        return rates_array[~np.isclose(rates_array, 0.0, rtol=1e-4, atol=1e-4)]

    def gid_firing_rate(spkdict, cell_index_set):
        """Return ``{gid: [rate of trial 0, rate of trial 1, ...]}``."""
        pop_spkdict = spkdict.get(population, {})
        rates_dict = defaultdict(list)
        for i in range(n_trials):
            for gid in cell_index_set:
                if gid in pop_spkdict:
                    spk_ts = pop_spkdict[gid][i]
                else:
                    spk_ts = np.asarray([], dtype=np.float32)
                this_rate = len(spk_ts) / tsecs
                logger.info(
                    f"firing rate objective: spike times of gid {gid}: {pprint.pformat(spk_ts)}"
                )
                logger.info(
                    f"firing rate objective: rate of gid {gid} is {this_rate:.02f}"
                )
                rates_dict[gid].append(this_rate)
        return rates_dict

    def gid_mean_v(t_offset, v_threshold, n_trials, t_rec, state_recs_dict):
        """Return ``{gid: [per-trial mean of v clipped at v_threshold, per recording]}``."""
        t_vec = np.asarray(t_rec.to_python(), dtype=np.float32)
        t_trial_inds = get_trial_time_indices(t_vec, n_trials, t_offset)
        results_dict = {}
        for gid, state_recs in state_recs_dict.items():
            state_values = []
            for rec in state_recs:
                vec = np.asarray(rec["vec"].to_python(), dtype=np.float32)
                state_values.append(
                    np.asarray(
                        [
                            np.mean(np.clip(vec[t_inds], None, v_threshold))
                            for t_inds in t_trial_inds
                        ]
                    )
                )
            results_dict[gid] = state_values
        return results_dict

    def mean_rate_diff(gid, rates, target_rate):
        nz_rates = nonzero_rates(rates)
        mean_rate = np.mean(nz_rates) if len(nz_rates) > 0 else 0.0
        return abs(mean_rate - target_rate)

    def best_rate_diff(gid, rates, target_rate):
        nz_rates = nonzero_rates(rates)
        max_rate = np.max(nz_rates) if len(nz_rates) > 0 else 0.0
        return abs(max_rate - target_rate)

    def eval_problem(cell_param_dict, **kwargs):
        spkdict = run_with(
            env,
            {
                population: {
                    gid: from_param_dict(cell_param_dict[gid])
                    for gid in my_cell_index_set
                }
            },
        )
        firing_rates_dict = gid_firing_rate(spkdict, my_cell_index_set)
        mean_v_dict = gid_mean_v(
            equilibration_duration,
            target_v_threshold,
            n_trials,
            env.t_rec,
            state_recs_dict,
        )
        if trial_regime == "mean":
            objectives_dict = {
                gid: -mean_rate_diff(gid, firing_rates_dict[gid], target_rate)
                for gid in my_cell_index_set
            }
        elif trial_regime == "best":
            objectives_dict = {
                gid: -best_rate_diff(gid, firing_rates_dict[gid], target_rate)
                for gid in my_cell_index_set
            }
        else:
            raise RuntimeError(f"rate_objfun: unknown trial regime {trial_regime}")

        features_dict = {}
        constraints_dict = {}
        for gid in my_cell_index_set:
            feature_array = np.empty(
                shape=(1,), dtype=np.dtype(opt_rate_feature_dtypes)
            )
            rates_array = np.asarray(firing_rates_dict[gid])
            nz_rates = nonzero_rates(rates_array)
            feature_array["mean_rate"] = np.mean(nz_rates) if len(nz_rates) > 0 else 0.0
            # trial_objs has shape (1, N_objectives, n_trials); index the
            # objective axis explicitly rather than the record axis.
            for i in range(N_objectives):
                feature_array["trial_objs"][0, i, :] = rates_array
            feature_array["mean_v"] = mean_v_dict[gid]
            features_dict[gid] = feature_array

        for gid in my_cell_index_set:
            constraints_dict[gid] = np.asarray([1], dtype=np.int8)
            if (
                np.mean(features_dict[gid]["mean_v"])
                >= target_v_threshold - target_v_margin
            ):
                objectives_dict[gid] -= 1e6
                constraints_dict[gid][0] = -1
        return objectives_dict, features_dict, constraints_dict

    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
def init_rate_dist_objfun(
    config_file,
    population,
    cell_index_set,
    arena_id,
    stimulus_id,
    n_trials,
    trial_regime,
    problem_regime,
    generate_weights,
    t_max,
    t_min,
    opt_iter,
    template_paths,
    dataset_prefix,
    results_path,
    spike_events_path,
    spike_events_namespace,
    spike_events_t,
    coords_path,
    distances_namespace,
    phase_mod,
    input_features_path,
    input_features_namespaces,
    param_type,
    param_config_name,
    recording_profile,
    target_features_path,
    target_features_namespace,
    target_features_arena,
    target_features_stimulus,
    use_coreneuron,
    cooperative_init,
    dt,
    worker,
    **kwargs,
):
    """
    Build the network clamp model and return an objective for fitting each
    cell's firing-rate time course to a target rate map.

    All arguments are also forwarded to ``Env`` as keyword arguments (with
    ``config_file`` renamed to ``config``). Target rate vectors come from
    ``rate_maps_from_features`` on ``target_features_path``. Simulated rates
    are estimated with ``spikedata.spike_density_estimate`` on bins of
    ``Temporal Resolution`` width, spanning the stimulus trajectory time
    (capped at ``t_max`` when given).

    :param trial_regime: ``"mean"`` scores the MSE of the trial-averaged rate
        vector; ``"best"`` scores the lowest per-trial MSE
    :return: ``opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)``
    """
    # Must run first: captures only the function arguments.
    params = dict(locals())
    params["config"] = params.pop("config_file")
    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)
    my_cell_index_set = init(
        env,
        population,
        cell_index_set,
        arena_id,
        stimulus_id,
        n_trials,
        spike_events_path,
        spike_events_namespace=spike_events_namespace,
        spike_train_attr_name=spike_events_t,
        coords_path=coords_path,
        distances_namespace=distances_namespace,
        phase_mod=phase_mod,
        input_features_path=input_features_path,
        input_features_namespaces=input_features_namespaces,
        generate_weights_pops=set(generate_weights),
        t_min=t_min,
        t_max=t_max,
        cooperative_init=cooperative_init,
        worker=worker,
    )
    time_step = env.stimulus_config["Temporal Resolution"]
    target_rate_vector_dict = rate_maps_from_features(
        env,
        population,
        cell_index_set=my_cell_index_set,
        input_features_path=target_features_path,
        input_features_namespace=target_features_namespace,
        time_range=None,
        arena_id=arena_id,
    )
    for target_rate_vector in target_rate_vector_dict.values():
        target_rate_vector[
            np.isclose(target_rate_vector, 0.0, atol=1e-3, rtol=1e-3)
        ] = 0.0

    _, trj_t = stimulus.read_stimulus(
        input_features_path
        if input_features_path is not None
        else spike_events_path,
        target_features_arena,
        target_features_stimulus,
    )
    # t_max is optional: without it, use the full trajectory duration.
    t_end = np.max(trj_t) if t_max is None else min(np.max(trj_t), t_max)
    time_range = (0.0, t_end)
    time_bins = np.arange(time_range[0], time_range[1] + time_step, time_step)

    opt_param_config = optimization_params(
        env.netclamp_config.optimize_parameters,
        [population],
        param_config_name,
        param_type,
    )
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    def from_param_dict(params_dict):
        return [
            (param_tuple, params_dict[param_pattern])
            for param_pattern, param_tuple in zip(param_names, param_tuples)
        ]

    def gid_firing_rate_vectors(spkdict, cell_index_set):
        """Return ``{gid: [rate vector of trial 0, trial 1, ...]}``."""
        pop_spkdict = spkdict.get(population, {})
        rates_dict = defaultdict(list)
        for i in range(n_trials):
            spkdict1 = {}
            for gid in cell_index_set:
                if gid in pop_spkdict:
                    spkdict1[gid] = pop_spkdict[gid][i]
                else:
                    spkdict1[gid] = np.asarray([], dtype=np.float32)
            spike_density_dict = spikedata.spike_density_estimate(
                population, spkdict1, time_bins
            )
            for gid in cell_index_set:
                rate_vector = spike_density_dict[gid]["rate"]
                rate_vector[np.isclose(rate_vector, 0.0, atol=1e-3, rtol=1e-3)] = 0.0
                rates_dict[gid].append(rate_vector)
            # Only log the cells evaluated here: other gids in spkdict have no
            # density estimate or rate vector.
            for gid in cell_index_set:
                logger.info(
                    f"firing rate objective: trial {i} firing rate of gid {gid}: {spike_density_dict[gid]}"
                )
                logger.info(
                    f"firing rate objective: trial {i} firing rate min/max of gid {gid}: "
                    f"{np.min(rates_dict[gid]):.02f} / {np.max(rates_dict[gid]):.02f} Hz"
                )
        return rates_dict

    def mean_trial_rate_mse(gid, rate_vectors, target_rate_vector):
        mean_rate_vector = np.mean(np.vstack(rate_vectors), axis=0)
        logger.info(
            f"firing rate objective: mean firing rate min/max of gid {gid}: "
            f"{np.min(mean_rate_vector):.02f} / {np.max(mean_rate_vector):.02f} Hz"
        )
        return np.square(np.subtract(mean_rate_vector, target_rate_vector)).mean()

    def best_trial_rate_mse(gid, rate_vectors, target_rate_vector):
        mses = [
            np.square(np.subtract(rate_vector, target_rate_vector)).mean()
            for rate_vector in rate_vectors
        ]
        min_mse_index = int(np.argmin(mses))
        best_rate_vector = rate_vectors[min_mse_index]
        logger.info(
            f"firing rate objective: best trial firing rate min/max of gid {gid}: "
            f"{np.min(best_rate_vector):.02f} / {np.max(best_rate_vector):.02f} Hz"
        )
        return mses[min_mse_index]

    def eval_problem(cell_param_dict, **kwargs):
        firing_rate_vectors_dict = gid_firing_rate_vectors(
            run_with(
                env,
                {
                    population: {
                        gid: from_param_dict(cell_param_dict[gid])
                        for gid in my_cell_index_set
                    }
                },
            ),
            my_cell_index_set,
        )
        if trial_regime == "mean":
            return {
                gid: -mean_trial_rate_mse(
                    gid,
                    firing_rate_vectors_dict[gid],
                    target_rate_vector_dict[gid],
                )
                for gid in my_cell_index_set
            }
        elif trial_regime == "best":
            return {
                gid: -best_trial_rate_mse(
                    gid,
                    firing_rate_vectors_dict[gid],
                    target_rate_vector_dict[gid],
                )
                for gid in my_cell_index_set
            }
        else:
            raise RuntimeError(
                f"firing_rate_dist: unknown trial regime {trial_regime}"
            )

    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
def _opt_result_config_tuples(param_names, param_tuples, params_dict):
    """Pair each optimized value with the description of its parameter.

    :param param_names: parameter patterns, aligned with ``param_tuples``.
    :param param_tuples: parameter descriptors exposing ``population``,
        ``source``, ``sec_type``, ``syn_name`` and ``param_path``.
    :param params_dict: mapping parameter pattern -> optimized value.
    :return: list of ``(population, source, sec_type, syn_name, param_path,
        value)`` tuples, in ``param_names`` order.
    """
    return [
        (
            param_tuple.population,
            param_tuple.source,
            param_tuple.sec_type,
            param_tuple.syn_name,
            param_tuple.param_path,
            params_dict[param_pattern],
        )
        for param_pattern, param_tuple in zip(param_names, param_tuples)
    ]


def optimize_run(
    env,
    pop_name,
    param_config_name,
    init_objfun,
    problem_regime,
    nprocs_per_worker=1,
    opt_iter=10,
    solver_epsilon=1e-2,
    opt_seed=None,
    param_type="synaptic",
    init_params=None,
    feature_dtypes=None,
    constraint_names=None,
    results_file=None,
    cooperative_init=False,
    verbose=False,
):
    """Run a distgfs optimization of network clamp parameters.

    :param env: environment providing ``netclamp_config``, ``results_path``
        and ``results_file_id``.
    :param pop_name: population whose parameters are optimized.
    :param param_config_name: name of the parameter configuration to use.
    :param init_objfun: name of the objective-initialization function in
        ``miv_simulator.network_clamp``.
    :param problem_regime: name of a ``ProblemRegime`` member ("every",
        "mean" or "max").
    :param init_params: keyword arguments forwarded to ``init_objfun``;
        defaults to an empty dict. For the "every" regime its
        ``cell_index_set`` entry supplies the problem ids.
    :param feature_dtypes: feature dtypes returned by the objective (only
        used for the "every" regime).
    :param constraint_names: constraint names returned by the objective.
    :param results_file: distgfs results file name, placed under
        ``env.results_path`` when that is set.
    :return: ``{pop_name: {gid: [param tuples]}}`` for the "every" regime,
        ``{pop_name: [param tuples]}`` otherwise, or ``None`` when
        ``distgfs.run`` returns ``None``.
    :raises RuntimeError: for an unhandled problem regime.
    """
    import distgfs

    # Avoid a shared mutable default; the dict is handed to distgfs.
    if init_params is None:
        init_params = {}

    opt_param_config = optimization_params(
        env.netclamp_config.optimize_parameters,
        [pop_name],
        param_config_name,
        param_type,
    )
    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    hyperprm_space = {
        param_pattern: [param_tuple.param_range[0], param_tuple.param_range[1]]
        for param_pattern, param_tuple in zip(param_names, param_tuples)
    }

    # Scalar (float) targets are stored as metadata with the results.
    float_target_keys = [
        k for k in sorted(opt_targets) if isinstance(opt_targets[k], float)
    ]
    problem_metadata = np.array(
        [tuple(opt_targets[k] for k in float_target_keys)],
        dtype=[(f"Target {k}", np.float32, (1,)) for k in float_target_keys],
    )

    if results_file is None:
        results_file = f"distgfs.network_clamp.{env.results_file_id}.h5"
    # Without a results directory, write relative to the working directory
    # instead of producing a literal "None/..." path.
    if env.results_path is not None:
        file_path = os.path.join(env.results_path, results_file)
    else:
        file_path = results_file

    problem_ids = None
    regime = ProblemRegime[problem_regime]
    if regime == ProblemRegime.every:
        # The reducer must match the extra outputs (features and/or
        # constraints) produced by the objective function.
        has_features = feature_dtypes is not None
        has_constraints = constraint_names is not None
        if has_features and has_constraints:
            reduce_fun_name = "opt_reduce_every_features_constraints"
        elif has_features:
            reduce_fun_name = "opt_reduce_every_features"
        elif has_constraints:
            reduce_fun_name = "opt_reduce_every_constraints"
        else:
            reduce_fun_name = "opt_reduce_every"
        problem_ids = init_params.get("cell_index_set", None)
    elif regime == ProblemRegime.mean:
        reduce_fun_name = "opt_reduce_mean"
        feature_dtypes = None
    elif regime == ProblemRegime.max:
        reduce_fun_name = "opt_reduce_max"
        feature_dtypes = None
    else:
        raise RuntimeError(
            f"optimize_run: unknown problem regime {problem_regime}"
        )

    distgfs_params = {
        "opt_id": "network_clamp.optimize",
        "problem_ids": problem_ids,
        "obj_fun_init_name": init_objfun,
        "obj_fun_init_module": "miv_simulator.network_clamp",
        "obj_fun_init_args": init_params,
        "reduce_fun_name": reduce_fun_name,
        "reduce_fun_module": "miv_simulator.optimization",
        "problem_parameters": {},
        "space": hyperprm_space,
        "feature_dtypes": feature_dtypes,
        "constraint_names": constraint_names,
        "file_path": file_path,
        "save": True,
        "n_iter": opt_iter,
        "seed": opt_seed,
        "solver_epsilon": solver_epsilon,
        "metadata": problem_metadata,
    }
    if cooperative_init:
        distgfs_params["broker_fun_name"] = "distgfs_broker_init"
        distgfs_params["broker_module_name"] = "miv_simulator.optimization"

    opt_results = distgfs.run(
        distgfs_params,
        verbose=verbose,
        collective_mode="sendrecv",
        spawn_workers=True,
        nprocs_per_worker=nprocs_per_worker,
    )
    if opt_results is None:
        return None

    if regime == ProblemRegime.every:
        gid_results_config_dict = {}
        for gid, opt_result in opt_results.items():
            results_config_tuples = _opt_result_config_tuples(
                param_names, param_tuples, dict(opt_result[0])
            )
            gid_results_config_dict[int(gid)] = results_config_tuples
            # Log only this gid's result rather than the growing dict.
            logger.info(
                "Optimized parameters and objective function: "
                f"{pprint.pformat({int(gid): results_config_tuples})} @"
                f"{opt_result[1]}"
            )
        return {pop_name: gid_results_config_dict}

    results_config_tuples = _opt_result_config_tuples(
        param_names, param_tuples, dict(opt_results[0])
    )
    logger.info(
        "Optimized parameters and objective function: "
        f"{pprint.pformat(results_config_tuples)} @"
        f"{opt_results[1]}"
    )
    return {pop_name: results_config_tuples}
def dist_ctrl(
    controller, init_params, cell_index_set, param_path, pop_param_tuple_dicts
):
    """Controller for distributed network clamp runs.

    Submits one ``dist_run`` task per parameter file in ``param_path``. Each
    task's results id is ``<results_file_id>_<file stem>``. If no parameter
    files are given, a single task is submitted without parameter overrides.
    The controller then waits for as many results as tasks were submitted
    and finally calls ``controller.info()``.
    """
    base_file_id = init_params.get("results_file_id", None)
    submitted = []

    def submit(file_id, param_tuple_dict):
        submitted.append(
            controller.submit_call(
                "dist_run",
                module_name="miv_simulator.network_clamp",
                args=(init_params, cell_index_set, file_id, param_tuple_dict),
            )
        )

    if len(param_path) > 0:
        for path, param_tuple_dict in zip(param_path, pop_param_tuple_dicts):
            stem, _ext = os.path.splitext(os.path.basename(path))
            submit(f"{base_file_id}_{stem}", param_tuple_dict)
    else:
        submit(None, None)

    for _ in range(len(submitted)):
        _task_id, _result = controller.get_next_result()
    controller.info()
def dist_run(
    init_params, cell_index_set, results_file_id=None, pop_param_tuple_dict=None
):
    """Initialize workers for distributed network clamp runs.

    The module-level ``env`` is created once per process (renaming
    ``config_file`` to ``config`` in ``init_params``) and cleared on every
    call. The cells in ``cell_index_set`` are then initialized and
    simulated, and the results are written out. When
    ``pop_param_tuple_dict`` is given, the run uses those parameter
    overrides and the parameters are written as well.

    :param init_params: Env keyword arguments plus run settings; modified
        in place (``results_file_id`` and, on first use, ``config``).
    :param cell_index_set: gids of the cells to clamp.
    :param results_file_id: results identifier. It falls back to
        ``init_params["results_file_id"]`` and then to a generated id.
    :param pop_param_tuple_dict: optional parameter overrides for ``run_with``.
    :return: None
    """
    global env

    if results_file_id is None:
        results_file_id = init_params.get("results_file_id", None)
    if results_file_id is None:
        results_file_id = generate_results_file_id(
            init_params["population"], seed=init_params.get("opt_seed", None)
        )
        init_params["results_file_id"] = results_file_id

    if env is None:
        init_params["config"] = init_params.pop("config_file", None)
        env = Env(**init_params)
        configure_hoc_env(env, bcast_template=True)
    env.clear()
    env.results_file_id = results_file_id
    env.results_file_path = (
        f"{env.results_path}/{env.modelName}_results_{env.results_file_id}.h5"
    )

    # Required run settings, looked up in a fixed order.
    required_keys = (
        "population",
        "arena_id",
        "stimulus_id",
        "spike_events_path",
        "spike_events_namespace",
        "spike_events_t",
        "coords_path",
        "distances_namespace",
        "phase_mod",
        "input_features_path",
        "input_features_namespaces",
        "t_min",
        "t_max",
        "n_trials",
    )
    cfg = {key: init_params[key] for key in required_keys}

    init(
        env,
        cfg["population"],
        cell_index_set,
        cfg["arena_id"],
        cfg["stimulus_id"],
        cfg["n_trials"],
        cfg["spike_events_path"],
        spike_events_namespace=cfg["spike_events_namespace"],
        coords_path=cfg["coords_path"],
        distances_namespace=cfg["distances_namespace"],
        phase_mod=cfg["phase_mod"],
        spike_train_attr_name=cfg["spike_events_t"],
        input_features_path=cfg["input_features_path"],
        input_features_namespaces=cfg["input_features_namespaces"],
        generate_weights_pops=set(init_params.get("generate_weights", [])),
        input_seed=init_params.get("input_seed", None),
        t_min=cfg["t_min"],
        t_max=cfg["t_max"],
    )

    if pop_param_tuple_dict is None:
        run(env)
        write_output(env)
    else:
        run_with(env, pop_param_tuple_dict)
        write_output(env)
        write_params(env, pop_param_tuple_dict)
    return None
def write_output(env):
    """Write spikes, intracellular recordings and synapse spike counts.

    Rank 0 creates the output file at ``env.results_file_path`` and emits
    the progress messages. After a barrier on ``env.comm``, the write calls
    themselves are issued on every rank.
    """
    on_root = env.comm.rank == 0

    def announce(message):
        if on_root:
            logger.info(message)

    if on_root:
        io_utils.mkout(env, env.results_file_path)
    env.comm.barrier()

    announce("*** Writing spike data")
    io_utils.spikeout(env, env.results_file_path)

    announce("*** Writing intracellular data")
    io_utils.recsout(
        env,
        env.results_file_path,
        write_cell_location_data=True,
        write_trial_data=True,
    )

    announce("*** Writing synapse spike counts")
    for pop_name in sorted(env.biophys_cells.keys()):
        source_names = sorted(env.projection_dict[pop_name])
        synapses.write_syn_spike_count(
            env,
            pop_name,
            env.results_file_path,
            filters={"sources": source_names},
            write_kwds={"io_size": env.io_size},
        )
def write_params(env, pop_params_dict):
    """Save the synapse parameter overrides used for a run.

    On rank 0, ``pop_params_dict`` is converted from population -> gid ->
    list of ``(syn_param, param_val)`` pairs into plain dictionaries. The
    result is passed to ``io_utils.write_params`` with
    ``env.results_file_path``. All ranks then wait at a barrier on
    ``env.comm``.
    """
    if env.comm.rank == 0:
        logger.info("*** Writing synapse parameters")
        serializable = {
            pop_name: {
                gid: [
                    {
                        "population": syn_param.population,
                        "source": syn_param.source,
                        "sec_type": syn_param.sec_type,
                        "syn_name": syn_param.syn_name,
                        "param_path": syn_param.param_path,
                        "param_val": float(param_val),
                    }
                    for syn_param, param_val in gid_params
                ]
                for gid, gid_params in gid_param_dict.items()
            }
            for pop_name, gid_param_dict in pop_params_dict.items()
        }
        io_utils.write_params(env.results_file_path, serializable)
    env.comm.barrier()
def show(
    config_file,
    config_prefix,
    population,
    gid,
    arena_id,
    stimulus_id,
    template_paths,
    dataset_prefix,
    results_path,
    spike_events_path,
    spike_events_namespace,
    spike_events_t,
    input_features_path,
    input_features_namespaces,
    use_coreneuron,
    plot_cell,
    write_cell,
    profile_memory,
    recording_profile,
):
    """
    Show configuration for the specified cell.

    Rank 0 builds an environment on a private communicator, initializes
    cell ``gid`` of ``population`` and logs its topology. The other ranks
    only take part in the collective communicator calls.
    """
    np.seterr(all="raise")
    verbose = True
    # Snapshot of the arguments (plus ``verbose``) forwarded to Env. It must
    # be taken before any other local variable is bound.
    init_params = dict(locals())
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    # Comm.Split is collective over ``comm``, so every rank must call it.
    # Calling it on rank 0 only would block when running with several ranks.
    comm0 = comm.Split(2 if rank == 0 else 1, 0)
    if rank == 0:
        init_params["config"] = init_params.pop("config_file", None)
        env = Env(**init_params, comm=comm0)
        configure_hoc_env(env)
        init(
            env,
            population,
            {gid},
            arena_id,
            stimulus_id,
            spike_events_path=spike_events_path,
            spike_events_namespace=spike_events_namespace,
            spike_train_attr_name=spike_events_t,
            input_features_path=input_features_path,
            input_features_namespaces=input_features_namespaces,
            plot_cell=plot_cell,
            write_cell=write_cell,
        )
        cell = env.biophys_cells[population][gid]
        logger.info(pprint.pformat(report_topology(cell, env)))
        if env.profile_memory:
            # ``profile_memory`` is a flag argument here (it shadows any
            # helper of that name), so report peak memory directly.
            # ru_maxrss is in KiB on Linux and in bytes on macOS.
            import resource

            peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            logger.info(f"peak resident set size (ru_maxrss): {peak_rss}")
    comm.barrier()
    comm0.Free()
def _netclamp_env_params(init_params):
    """Return a copy of ``init_params`` with ``config_file`` renamed to ``config``.

    The copy is what gets passed to ``Env``. The caller's dict keeps its
    ``config_file`` entry, so it can still be forwarded (e.g. to
    ``dist_run``, which performs the same renaming itself).
    """
    env_params = dict(init_params)
    env_params["config"] = env_params.pop("config_file", None)
    return env_params


def _load_pop_params_tuple_dicts(params_path, params_id):
    """Read synapse parameter sets from YAML files.

    :param params_path: sequence of YAML paths. Each file maps population ->
        gid -> list of ``(population, source, sec_type, syn_name,
        param_path, param_val)`` entries. When the matching ``params_id``
        entry is not None, each gid maps to a dict keyed by that id instead.
    :param params_id: parameter-set keys aligned with ``params_path``; an
        empty sequence means "no key" for every file.
    :return: one dict per file, mapping population -> gid -> list of
        ``(SynParam, param_val)`` tuples.
    """
    if len(params_id) == 0:
        params_id = [None] * len(params_path)
    pop_params_tuple_dicts = []
    for this_params_path, this_param_id in zip(params_path, params_id):
        pop_params_dict = read_from_yaml(this_params_path)
        pop_params_tuple_dict = {}
        for this_pop_name, this_pop_param_dict in pop_params_dict.items():
            this_pop_params_tuple_dict = defaultdict(list)
            for this_gid, this_gid_params in this_pop_param_dict.items():
                if this_param_id is not None:
                    this_gid_params_list = this_gid_params[this_param_id]
                else:
                    this_gid_params_list = this_gid_params
                for (
                    this_population,
                    source,
                    sec_type,
                    syn_name,
                    param_path,
                    param_val,
                ) in this_gid_params_list:
                    syn_param = SynParam(
                        this_population,
                        source,
                        sec_type,
                        syn_name,
                        param_path,
                        None,
                    )
                    this_pop_params_tuple_dict[this_gid].append(
                        (syn_param, param_val)
                    )
            pop_params_tuple_dict[this_pop_name] = dict(
                this_pop_params_tuple_dict
            )
        pop_params_tuple_dicts.append(pop_params_tuple_dict)
    return pop_params_tuple_dicts


def go(
    config_file,
    config_prefix,
    population,
    dt,
    gids,
    gid_selection_file,
    arena_id,
    stimulus_id,
    generate_weights,
    t_max,
    t_min,
    template_paths,
    dataset_prefix,
    spike_events_path,
    spike_events_namespace,
    spike_events_t,
    coords_path,
    distances_namespace,
    phase_mod,
    input_features_path,
    input_features_namespaces,
    n_trials,
    params_path,
    params_id,
    results_path,
    results_file_id,
    results_namespace_id,
    use_coreneuron,
    plot_cell,
    write_cell,
    profile_memory,
    recording_profile,
    input_seed,
):
    """
    Runs network clamp simulation for the specified gid, or for all gids found in the input data file.

    Cells are taken from ``gid_selection_file`` (one gid per line, blank
    lines ignored), else from ``gids``, else from the "Trees" attribute of
    ``population`` in the data file. With more than one MPI rank the runs
    are distributed via ``distwq`` (one task per parameter file in
    ``params_path``). Otherwise they run in this process. When parameter
    files are given, each produces its own results file, suffixed with the
    file's stem.
    """
    # Snapshot of the arguments; must come before any other local binding.
    init_params = dict(locals())
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()
    np.seterr(all="raise")
    verbose = True
    init_params["verbose"] = verbose
    config_logging(verbose)

    pop_params_tuple_dicts = None
    if rank == 0:
        if results_file_id is None:
            results_file_id = generate_results_file_id(
                population, seed=input_seed
            )
        if len(params_path) > 0:
            pop_params_tuple_dicts = _load_pop_params_tuple_dicts(
                params_path, params_id
            )
    results_file_id = comm.bcast(results_file_id, root=0)
    init_params["results_file_id"] = results_file_id
    pop_params_tuple_dicts = comm.bcast(pop_params_tuple_dicts, root=0)

    cell_index_set = set()
    if gid_selection_file is not None:
        with open(gid_selection_file) as f:
            cell_index_set.update(int(line) for line in f if line.strip())
    elif gids is not None:
        cell_index_set.update(gids)
    else:
        comm.barrier()
        comm0 = comm.Split(2 if rank == 0 else 1, 0)
        if rank == 0:
            # Use a copy so that init_params keeps ``config_file`` for the
            # Env constructions (and dist_run calls) that follow.
            env = Env(**_netclamp_env_params(init_params), comm=comm0)
            attr_info_dict = read_cell_attribute_info(
                env.data_file_path,
                populations=[population],
                read_cell_index=True,
                comm=comm0,
            )
            _, attr_cell_index = next(
                iter(attr_info_dict[population]["Trees"])
            )
            cell_index_set = set(attr_cell_index)
        comm.barrier()
        cell_index_set = comm.bcast(cell_index_set, root=0)
        comm.barrier()
        comm0.Free()

    if size > 1:
        import distwq

        if distwq.is_controller:
            distwq.run(
                fun_name="dist_ctrl",
                module_name="miv_simulator.network_clamp",
                verbose=True,
                args=(
                    init_params,
                    cell_index_set,
                    params_path,
                    pop_params_tuple_dicts,
                ),
                spawn_workers=True,
                nprocs_per_worker=1,
            )
        else:
            distwq.run(verbose=True, spawn_workers=True, nprocs_per_worker=1)
    else:
        env = Env(**_netclamp_env_params(init_params), comm=comm)
        configure_hoc_env(env)
        init(
            env,
            population,
            cell_index_set,
            arena_id,
            stimulus_id,
            n_trials,
            spike_events_path,
            spike_events_namespace=spike_events_namespace,
            spike_train_attr_name=spike_events_t,
            coords_path=coords_path,
            distances_namespace=distances_namespace,
            phase_mod=phase_mod,
            input_features_path=input_features_path,
            input_features_namespaces=input_features_namespaces,
            generate_weights_pops=set(generate_weights),
            t_min=t_min,
            t_max=t_max,
            input_seed=input_seed,
            plot_cell=plot_cell,
            write_cell=write_cell,
        )
        if pop_params_tuple_dicts is not None:
            for this_params_path, pop_params_tuple_dict in zip(
                params_path, pop_params_tuple_dicts
            ):
                params_basename = os.path.splitext(
                    os.path.basename(this_params_path)
                )[0]
                env.results_file_id = f"{results_file_id}_{params_basename}"
                env.results_file_path = f"{env.results_path}/{env.modelName}_results_{env.results_file_id}.h5"
                run_with(env, pop_params_tuple_dict)
                write_output(env)
                write_params(env, pop_params_tuple_dict)
        else:
            run(env)
            write_output(env)
        # ``env`` is only guaranteed to be bound in this branch. The
        # ``profile_memory`` argument is a flag (it shadows any helper of
        # that name), so report peak memory directly.
        # ru_maxrss is in KiB on Linux and in bytes on macOS.
        if env.profile_memory:
            import resource

            peak_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            logger.info(f"peak resident set size (ru_maxrss): {peak_rss}")
def optimize(
    config_file,
    config_prefix,
    population,
    dt,
    gids,
    gid_selection_file,
    arena_id,
    stimulus_id,
    generate_weights,
    t_max,
    t_min,
    nprocs_per_worker,
    opt_epsilon,
    opt_seed,
    opt_iter,
    template_paths,
    dataset_prefix,
    param_config_name,
    param_type,
    recording_profile,
    results_file,
    results_path,
    spike_events_path,
    spike_events_namespace,
    spike_events_t,
    coords_path,
    distances_namespace,
    phase_mod,
    input_features_path,
    input_features_namespaces,
    n_trials,
    trial_regime,
    problem_regime,
    target_features_path,
    target_features_namespace,
    target_state_variable,
    target_state_filter,
    use_coreneuron,
    cooperative_init,
    target,
):
    """
    Optimize the firing rate of the specified cell in a network clamp configuration.

    Cells are taken from ``gid_selection_file`` (one gid per line, blank
    lines ignored), else from ``gids``, else from the "Trees" attribute of
    ``population`` in the data file. ``target`` selects the objective
    ("rate", "state", or "ratedist"/"rate_dist"). The optimized parameters
    are written to ``<results_path>/network_clamp.optimize.<id>.yaml`` when
    ``results_path`` is set.

    NOTE: ``params = dict(locals())`` below forwards every local variable
    bound up to that point to ``Env``, so new locals should not be added
    before it casually.
    """
    init_params = dict(locals())
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()
    results_file_id = None
    if rank == 0:
        ts = time.strftime("%Y%m%d_%H%M%S")
        opt_seed_lab = (
            f"NOS{np.random.randint(99999999):08d}"
            if opt_seed is None
            else f"{opt_seed:08d}"
        )
        results_file_id = f"{population!s}_{ts!s}_{opt_seed_lab!s}"
    results_file_id = comm.bcast(results_file_id, root=0)
    comm.barrier()
    np.seterr(all="raise")
    verbose = True
    cache_queries = True
    cell_index_set = set()
    if gid_selection_file is not None:
        with open(gid_selection_file) as f:
            lines = f.readlines()
            for line in lines:
                if not line.strip():
                    # Tolerate blank lines (e.g. trailing empty lines).
                    continue
                gid = int(line)
                cell_index_set.add(gid)
    elif gids is not None:
        for gid in gids:
            cell_index_set.add(gid)
    else:
        comm.barrier()
        comm0 = comm.Split(2 if rank == 0 else 1, 0)
        if rank == 0:
            # Build the Env kwargs without mutating init_params, which is
            # later forwarded (with ``config_file``) to the objective init.
            env = Env(
                **{k: v for k, v in init_params.items() if k != "config_file"},
                config=init_params.get("config_file", None),
                comm=comm0,
            )
            attr_info_dict = read_cell_attribute_info(
                env.data_file_path,
                populations=[population],
                read_cell_index=True,
                comm=comm0,
            )
            cell_index = None
            attr_name, attr_cell_index = next(
                iter(attr_info_dict[population]["Trees"])
            )
            cell_index_set = set(attr_cell_index)
        comm.barrier()
        cell_index_set = comm.bcast(cell_index_set, root=0)
        comm.barrier()
        comm0.Free()
    init_params["cell_index_set"] = cell_index_set
    del init_params["gids"]
    N_objectives = 1
    opt_rate_feature_dtypes = [
        ("mean_rate", (np.float32, (1,))),
        ("trial_objs", (np.float32, (N_objectives, n_trials))),
        ("mean_v", (np.float32, (n_trials,))),
    ]
    params = dict(locals())
    params["config"] = params.pop("config_file", None)
    env = Env(**params)
    if size == 1:
        configure_hoc_env(env)
        init(
            env,
            population,
            cell_index_set,
            arena_id,
            stimulus_id,
            n_trials,
            spike_events_path,
            spike_events_namespace=spike_events_namespace,
            spike_train_attr_name=spike_events_t,
            coords_path=coords_path,
            distances_namespace=distances_namespace,
            phase_mod=phase_mod,
            input_features_path=input_features_path,
            input_features_namespaces=input_features_namespaces,
            generate_weights_pops=set(generate_weights),
            t_min=t_min,
            t_max=t_max,
        )
    if population in env.netclamp_config.optimize_parameters[param_type]:
        opt_params = env.netclamp_config.optimize_parameters[param_type][
            population
        ]
    else:
        raise RuntimeError(
            f"network_clamp.optimize: population {population} does not have optimization configuration"
        )
    if target == "rate":
        opt_target = opt_params["Targets"]["firing rate"]
        init_params["target_rate"] = opt_target
        init_objfun_name = "init_rate_objfun"
        feature_dtypes = opt_rate_feature_dtypes
        constraint_names = ["mean_v_below_threshold"]
    elif target == "state":
        assert target_state_variable is not None
        opt_target = opt_params["Targets"]["state"][target_state_variable][
            "mean"
        ]
        init_params["target_value"] = opt_target
        init_params["state_variable"] = target_state_variable
        init_params["state_filter"] = target_state_filter
        init_objfun_name = "init_state_objfun"
        feature_dtypes = None
        constraint_names = None
    elif target == "ratedist" or target == "rate_dist":
        init_params["target_features_arena"] = arena_id
        # The stimulus id takes the role of the former trajectory id; the
        # previous reference to an undefined ``trajectory_id`` raised NameError.
        init_params["target_features_trajectory"] = stimulus_id
        init_objfun_name = "init_rate_dist_objfun"
        feature_dtypes = None
        constraint_names = None
    else:
        raise RuntimeError(
            f"network_clamp.optimize: unknown optimization target {target}"
        )
    results_config_dict = optimize_run(
        env,
        population,
        param_config_name,
        init_objfun_name,
        problem_regime=problem_regime,
        opt_iter=opt_iter,
        solver_epsilon=opt_epsilon,
        opt_seed=opt_seed,
        param_type=param_type,
        init_params=init_params,
        feature_dtypes=feature_dtypes,
        constraint_names=constraint_names,
        results_file=results_file,
        nprocs_per_worker=nprocs_per_worker,
        cooperative_init=cooperative_init,
        verbose=verbose,
    )
    if results_config_dict is not None:
        if results_path is not None:
            file_path = (
                f"{results_path}/network_clamp.optimize.{results_file_id}.yaml"
            )
            write_to_yaml(file_path, results_config_dict)
    comm.barrier()