from typing import Optional, Tuple
import gc
import logging
import os
import os.path
import random
import sys
from collections import defaultdict
import h5py
import numpy as np
from miv_simulator.env import Env
from miv_simulator.geometry.geometry import (
get_layer_extents,
get_total_extents,
load_alpha_shape,
make_alpha_shape,
save_alpha_shape,
uvl_in_bounds,
)
from miv_simulator import config
from miv_simulator.utils import config_logging, get_script_logger
from miv_simulator.volume import (
make_network_volume,
network_volume,
network_volume_transform,
)
from mpi4py import MPI
from neuroh5.io import append_cell_attributes, read_population_ranges
from rbf.pde.geometry import contains
from rbf.pde.nodes import disperse, min_energy_nodes
from scipy.spatial import cKDTree
script_name = os.path.basename(__file__)
logger = get_script_logger(script_name)
sys_excepthook = sys.excepthook
def mpi_excepthook(type, value, traceback):
"""
:param type:
:param value:
:param traceback:
:return:
"""
sys_excepthook(type, value, traceback)
if MPI.COMM_WORLD.size > 1:
MPI.COMM_WORLD.Abort(1)
def random_subset(iterator, K):
result = []
N = 0
for item in iterator:
N += 1
if len(result) < K:
result.append(item)
else:
s = int(random.random() * N)
if s < K:
result[s] = item
return result
def gen_min_energy_nodes(
count, domain, constraint, nodeiter, dispersion_delta, snap_delta
):
N = int(count * 2) # layer-specific number of nodes
node_count = 0
while node_count < count:
# create N quasi-uniformly distributed nodes
def rho(x):
return np.ones(x.shape[0])
# nodes = rejection_sampling(N, rho, (vert, smp), start=0)
out = min_energy_nodes(
N,
domain,
iterations=nodeiter,
**{"dispersion_delta": dispersion_delta, "snap_delta": snap_delta},
)
nodes = out[0]
# remove nodes with nan
nodes1 = nodes[
~np.logical_or.reduce(
(
np.isnan(nodes[:, 0]),
np.isnan(nodes[:, 1]),
np.isnan(nodes[:, 2]),
)
)
]
if len(nodes) != len(nodes1):
logger.info(
f"{len(nodes) - len(nodes1)} nodes out of {len(nodes)} were NaN"
)
# remove nodes outside of the domain
vert, smp = domain
in_nodes = nodes[contains(nodes1, vert, smp)]
valid_idxs = None
if constraint is not None:
valid_idxs = []
current_xyz = in_nodes.reshape(-1, 3)
for i in range(len(current_xyz)):
if (
current_xyz[i][2] >= constraint[0]
and current_xyz[i][2] <= constraint[1]
):
valid_idxs.append(i)
if len(valid_idxs) == 0:
logger.info(
f"Warning: all in_nodes have been rejected due to constraint {constraint}!"
)
elif len(valid_idxs) != len(in_nodes):
logger.info(
f"Removing {len(in_nodes)-len(valid_idxs)} out of {len(in_nodes)} nodes due to constraint {constraint}"
)
in_nodes = in_nodes[valid_idxs]
node_count = len(in_nodes)
N = int(1.5 * N)
logger.info(
"%i interior nodes out of %i nodes generated"
% (node_count, len(nodes))
)
return in_nodes
# !for imperative API, use generate_network_architecture() instead
def generate_soma_coordinates(
config: str,
types_path: str,
output_path: str,
geometry_path: Optional[str] = None,
output_namespace: str = "Generated Coordinates",
populations: Tuple[str, ...] = (),
resolution: Tuple[int, int, int] = (3, 3, 3),
alpha_radius: float = 2500.0,
nodeiter: int = 10,
dispersion_delta: float = 0.1,
snap_delta: float = 0.01,
io_size: int = -1,
chunk_size: int = 1000,
value_chunk_size: int = 1000,
verbose: bool = False,
config_prefix="",
):
config_logging(verbose)
env = Env(comm=MPI.COMM_WORLD, config=config, config_prefix=config_prefix)
# Note that using linearly hardcoded random seeds
# is not recommended (https://dl.acm.org/doi/10.1145/1276927.1276928)
# and hence only part of this deprecated API
random_seed = int(env.model_config["Random Seeds"]["Soma Locations"])
random.seed(random_seed)
return generate_network_architecture(
output_filepath=output_path,
h5_types_filepath=types_path,
layer_extents=env.geometry["Parametric Surface"]["Layer Extents"],
rotation=env.geometry["Parametric Surface"]["Rotation"],
cell_distributions=env.geometry["Cell Distribution"],
cell_constraints=env.geometry.get("Cell Constraints", None),
output_namespace=output_namespace,
geometry_filepath=geometry_path,
populations=populations,
resolution=resolution,
alpha_radius=alpha_radius,
nodeiter=nodeiter,
dispersion_delta=dispersion_delta,
snap_delta=snap_delta,
io_size=io_size,
chunk_size=chunk_size,
value_chunk_size=value_chunk_size,
)
[docs]def generate_network_architecture(
output_filepath: str,
cell_distributions: config.CellDistributions,
layer_extents: config.LayerExtents,
rotation: config.Rotation,
cell_constraints: Optional[config.CellConstraints],
output_namespace: str,
geometry_filepath: Optional[str],
populations: Optional[Tuple[str, ...]],
resolution: Tuple[int, int, int],
alpha_radius: float,
nodeiter: int,
dispersion_delta: float,
snap_delta: float,
h5_types_filepath: Optional[str],
io_size: int,
chunk_size: int,
value_chunk_size: int,
):
logger = get_script_logger(script_name)
comm = MPI.COMM_WORLD
rank = comm.rank
size = comm.size
np.seterr(all="raise")
if io_size == -1:
io_size = comm.size
if rank == 0:
logger.info("%i ranks have been allocated" % comm.size)
if rank == 0:
if h5_types_filepath and not os.path.isfile(output_filepath):
input_file = h5py.File(h5_types_filepath, "r")
output_file = h5py.File(output_filepath, "w")
input_file.copy("/H5Types", output_file)
input_file.close()
output_file.close()
comm.barrier()
(extent_u, extent_v, extent_l) = get_total_extents(layer_extents)
vol = make_network_volume(
extent_u, extent_v, extent_l, rotate=rotation, resolution=resolution
)
layer_alpha_shape_path = "Layer Alpha Shape/%d/%d/%d" % tuple(resolution)
if rank == 0:
logger.info(
"Constructing alpha shape for volume: extents: %s..."
% str((extent_u, extent_v, extent_l))
)
vol_alpha_shape_path = f"{layer_alpha_shape_path}/all"
if geometry_filepath:
vol_alpha_shape = load_alpha_shape(
geometry_filepath, vol_alpha_shape_path
)
else:
vol_alpha_shape = make_alpha_shape(vol, alpha_radius=alpha_radius)
if geometry_filepath:
save_alpha_shape(
geometry_filepath, vol_alpha_shape_path, vol_alpha_shape
)
vert = vol_alpha_shape.points
smp = np.asarray(vol_alpha_shape.bounds, dtype=np.int64)
vol_domain = (vert, smp)
layer_alpha_shapes = {}
layer_extent_vals = {}
layer_extent_transformed_vals = {}
if rank == 0:
for layer, extents in layer_extents.items():
(extent_u, extent_v, extent_l) = get_layer_extents(
layer_extents, layer
)
layer_extent_vals[layer] = (extent_u, extent_v, extent_l)
layer_extent_transformed_vals[layer] = network_volume_transform(
extent_u, extent_v, extent_l
)
has_layer_alpha_shape = False
if geometry_filepath:
this_layer_alpha_shape_path = (
f"{layer_alpha_shape_path}/{layer}"
)
this_layer_alpha_shape = load_alpha_shape(
geometry_filepath, this_layer_alpha_shape_path
)
layer_alpha_shapes[layer] = this_layer_alpha_shape
if this_layer_alpha_shape is not None:
has_layer_alpha_shape = True
if not has_layer_alpha_shape:
logger.info(
"Constructing alpha shape for layers {}: extents: {}...".format(
layer, str(extents)
)
)
layer_vol = make_network_volume(
extent_u,
extent_v,
extent_l,
rotate=rotation,
resolution=resolution,
)
this_layer_alpha_shape = make_alpha_shape(
layer_vol, alpha_radius=alpha_radius
)
layer_alpha_shapes[layer] = this_layer_alpha_shape
if geometry_filepath:
save_alpha_shape(
geometry_filepath,
this_layer_alpha_shape_path,
this_layer_alpha_shape,
)
comm.barrier()
population_ranges = read_population_ranges(output_filepath, comm)[0]
if not populations:
populations = sorted(population_ranges.keys())
total_count = 0
for population in populations:
(population_start, population_count) = population_ranges[population]
total_count += population_count
all_xyz_coords1 = None
generated_coords_count_dict = defaultdict(int)
if rank == 0:
all_xyz_coords_lst = []
for population in populations:
gc.collect()
(population_start, population_count) = population_ranges[population]
pop_layers = cell_distributions[population]
pop_constraint = None
if cell_constraints is not None:
if population in cell_constraints:
pop_constraint = cell_constraints[population]
if rank == 0:
logger.info(
f"Population {population}: layer distribution is {pop_layers}"
)
pop_layer_count = 0
for layer, count in pop_layers.items():
pop_layer_count += count
if population_count != pop_layer_count:
logger.error(
f"Population {population}: mismatch between total count {population_count} and sum of per-layer counts {pop_layer_count}"
)
assert population_count == pop_layer_count
xyz_coords_lst = []
for layer, count in pop_layers.items():
if count <= 0:
continue
alpha = layer_alpha_shapes[layer]
vert = alpha.points
smp = np.asarray(alpha.bounds, dtype=np.int64)
extents_xyz = layer_extent_transformed_vals[layer]
for vvi, vv in enumerate(vert):
for vi, v in enumerate(vv):
if v < extents_xyz[vi][0]:
vert[vvi][vi] = extents_xyz[vi][0]
elif v > extents_xyz[vi][1]:
vert[vvi][vi] = extents_xyz[vi][1]
N = int(count * 2) # layer-specific number of nodes
node_count = 0
logger.info(
"Generating %i nodes in layer %s for population %s..."
% (N, layer, population)
)
if False: # verbose
rbf_logger = logging.Logger.manager.loggerDict[
"rbf.pde.nodes"
]
rbf_logger.setLevel(logging.DEBUG)
min_energy_constraint = None
if pop_constraint is not None and layer in pop_constraint:
min_energy_constraint = pop_constraint[layer]
nodes = gen_min_energy_nodes(
count,
(vert, smp),
min_energy_constraint,
nodeiter,
dispersion_delta,
snap_delta,
)
# nodes = gen_min_energy_nodes(count, (vert, smp),
# pop_constraint[layer] if pop_constraint is not None else None,
# nodeiter, dispersion_delta, snap_delta)
xyz_coords_lst.append(nodes.reshape(-1, 3))
for this_xyz_coords in xyz_coords_lst:
all_xyz_coords_lst.append(this_xyz_coords)
generated_coords_count_dict[population] += len(this_xyz_coords)
# Additional dispersion step to ensure no overlapping cell positions
all_xyz_coords = np.row_stack(all_xyz_coords_lst)
mask = np.ones((all_xyz_coords.shape[0],), dtype=np.bool_)
# distance to nearest neighbor
while True:
kdt = cKDTree(all_xyz_coords[mask, :])
nndist, nnindices = kdt.query(all_xyz_coords[mask, :], k=2)
nndist, nnindices = nndist[:, 1:], nnindices[:, 1:]
zindices = nnindices[
np.argwhere(np.isclose(nndist, 0.0, atol=1e-3, rtol=1e-3))
]
if len(zindices) > 0:
mask[np.argwhere(mask)[zindices]] = False
else:
break
coords_offset = 0
for population in populations:
pop_coords_count = generated_coords_count_dict[population]
pop_mask = mask[coords_offset : coords_offset + pop_coords_count]
generated_coords_count_dict[population] = np.count_nonzero(pop_mask)
coords_offset += pop_coords_count
logger.info("Dispersion of %i nodes..." % np.count_nonzero(mask))
all_xyz_coords1 = disperse(
all_xyz_coords[mask, :], vol_domain, delta=dispersion_delta
)
if rank == 0:
logger.info(
f"Computing UVL coordinates of {len(all_xyz_coords1)} nodes..."
)
all_xyz_coords_interp = None
all_uvl_coords_interp = None
if rank == 0:
all_uvl_coords_interp = vol.inverse(all_xyz_coords1)
all_xyz_coords_interp = (
vol(
all_uvl_coords_interp[:, 0],
all_uvl_coords_interp[:, 1],
all_uvl_coords_interp[:, 2],
mesh=False,
)
.reshape(3, -1)
.T
)
if rank == 0:
logger.info("Broadcasting generated nodes...")
xyz_coords = comm.bcast(all_xyz_coords1, root=0)
all_xyz_coords_interp = comm.bcast(all_xyz_coords_interp, root=0)
all_uvl_coords_interp = comm.bcast(all_uvl_coords_interp, root=0)
generated_coords_count_dict = comm.bcast(
dict(generated_coords_count_dict), root=0
)
coords_offset = 0
pop_coords_dict = {}
for population in populations:
xyz_error = np.asarray([0.0, 0.0, 0.0])
pop_layers = cell_distributions[population]
pop_start, pop_count = population_ranges[population]
coords = []
gen_coords_count = generated_coords_count_dict[population]
for i, coord_ind in enumerate(
range(coords_offset, coords_offset + gen_coords_count)
):
if i % size == rank:
uvl_coords = all_uvl_coords_interp[coord_ind, :].ravel()
xyz_coords1 = all_xyz_coords_interp[coord_ind, :].ravel()
if uvl_in_bounds(
all_uvl_coords_interp[coord_ind, :],
layer_extents,
pop_layers,
):
xyz_error = np.add(
xyz_error,
np.abs(
np.subtract(xyz_coords[coord_ind, :], xyz_coords1)
),
)
logger.info(
"Rank %i: %s cell %i: %f %f %f"
% (
rank,
population,
i,
uvl_coords[0],
uvl_coords[1],
uvl_coords[2],
)
)
coords.append(
(
xyz_coords1[0],
xyz_coords1[1],
xyz_coords1[2],
uvl_coords[0],
uvl_coords[1],
uvl_coords[2],
)
)
else:
logger.debug(
"Rank %i: %s cell %i not in bounds: %f %f %f"
% (
rank,
population,
i,
uvl_coords[0],
uvl_coords[1],
uvl_coords[2],
)
)
uvl_coords = None
xyz_coords1 = None
total_xyz_error = np.zeros((3,))
comm.Allreduce(xyz_error, total_xyz_error, op=MPI.SUM)
coords_count = 0
coords_count = np.sum(np.asarray(comm.allgather(len(coords))))
if coords_count == 0:
mean_xyz_error = np.asarray([0.0, 0.0, 0.0])
else:
mean_xyz_error = np.asarray(
[
(total_xyz_error[0] / coords_count),
(total_xyz_error[1] / coords_count),
(total_xyz_error[2] / coords_count),
]
)
pop_coords_dict[population] = coords
coords_offset += gen_coords_count
if rank == 0:
logger.info(
"Total %i coordinates generated for population %s: mean XYZ error: %f %f %f"
% (
coords_count,
population,
mean_xyz_error[0],
mean_xyz_error[1],
mean_xyz_error[2],
)
)
if rank == 0:
color = 1
else:
color = 0
## comm0 includes only rank 0
comm0 = comm.Split(color, 0)
for population in populations:
pop_start, pop_count = population_ranges[population]
pop_layers = cell_distributions[population]
pop_constraint = None
if cell_constraints is not None:
if population in cell_constraints:
pop_constraint = cell_constraints[population]
coords_lst = comm.gather(pop_coords_dict[population], root=0)
if rank == 0:
all_coords = []
for sublist in coords_lst:
for item in sublist:
all_coords.append(item)
coords_count = len(all_coords)
if coords_count < pop_count:
logger.warning(
"Generating additional %i coordinates for population %s..."
% (pop_count - len(all_coords), population)
)
safety = 0.01
delta = pop_count - len(all_coords)
for i in range(delta):
for layer, count in pop_layers.items():
if count > 0:
min_extent = layer_extents[layer][0]
max_extent = layer_extents[layer][1]
coord_u = np.random.uniform(
min_extent[0] + safety, max_extent[0] - safety
)
coord_v = np.random.uniform(
min_extent[1] + safety, max_extent[1] - safety
)
if (
pop_constraint is None
or layer not in pop_constraint
):
coord_l = np.random.uniform(
min_extent[2] + safety,
max_extent[2] - safety,
)
else:
coord_l = np.random.uniform(
pop_constraint[layer][0] + safety,
pop_constraint[layer][1] - safety,
)
xyz_coords = network_volume(
coord_u, coord_v, coord_l, rotate=rotation
).ravel()
all_coords.append(
(
xyz_coords[0],
xyz_coords[1],
xyz_coords[2],
coord_u,
coord_v,
coord_l,
)
)
sampled_coords = random_subset(all_coords, int(pop_count))
sampled_coords.sort(
key=lambda coord: coord[3]
) ## sort on U coordinate
coords_dict = {
pop_start
+ i: {
"X Coordinate": np.asarray([x_coord], dtype=np.float32),
"Y Coordinate": np.asarray([y_coord], dtype=np.float32),
"Z Coordinate": np.asarray([z_coord], dtype=np.float32),
"U Coordinate": np.asarray([u_coord], dtype=np.float32),
"V Coordinate": np.asarray([v_coord], dtype=np.float32),
"L Coordinate": np.asarray([l_coord], dtype=np.float32),
}
for (
i,
(x_coord, y_coord, z_coord, u_coord, v_coord, l_coord),
) in enumerate(sampled_coords)
}
append_cell_attributes(
output_filepath,
population,
coords_dict,
namespace=output_namespace,
io_size=io_size,
chunk_size=chunk_size,
value_chunk_size=value_chunk_size,
comm=comm0,
)
comm.barrier()
comm0.Free()