from typing import List, Optional, Any, Tuple
import copy
import sys
import time
from collections import defaultdict
from mpi4py import MPI
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
from matplotlib.colors import BoundaryNorm
from matplotlib.offsetbox import AnchoredText
from matplotlib.ticker import MaxNLocator
from matplotlib.animation import FuncAnimation, writers
from miv_simulator import cells, spikedata, statedata, stimulus, synapses
from miv_simulator.volume import network_volume
from miv_simulator.env import Env
from miv_simulator.utils import (
Struct,
apply_filter,
butter_bandpass_filter,
get_module_logger,
make_geometric_graph,
signal_power_spectrogram,
signal_psd,
zip_longest,
add_bins,
update_bins,
finalize_bins,
get_low_pass_filtered_trace,
)
from miv_simulator.utils.neuron import h, interplocs
from mpl_toolkits.axes_grid1 import make_axes_locatable
from neuroh5.io import (
read_cell_attributes,
read_population_names,
read_population_ranges,
NeuroH5ProjectionGen,
bcast_cell_attributes,
)
from scipy import interpolate, ndimage, signal
# Initialise NEURON's MPI layer, but only on NEURON builds that expose it.
if hasattr(h, "nrnmpi_init"):
    h.nrnmpi_init()

# Module-level logger; it inherits its settings from the root logger that is
# created in miv_simulator.env.
logger = get_module_logger(__name__)

# Default figure configuration. The plotting functions below copy this Struct
# and override individual fields with their own keyword arguments.
default_fig_options = Struct(
    **{
        "figFormat": "png",
        "lw": 2,
        "figSize": (10, 8),
        "fontSize": 14,
        "saveFig": None,
        "showFig": True,
        "colormap": "jet",
        "saveFigDir": None,
    }
)
# Default palette of visually distinct colours. Plotting functions index into
# it (usually modulo its length). Every entry must be a valid "#RRGGBB" string,
# otherwise matplotlib / hex2rgb fail as soon as that slot is reached.
dflt_colors = [
    "#009BFF", "#E85EBE", "#00FF00", "#0000FF",
    "#FF0000", "#01FFFE", "#FFA6FE", "#FFDB66",
    "#006401", "#010067", "#95003A", "#007DB5",
    "#FF00F6", "#FFEEE8", "#774D00", "#90FB92",
    "#0076FF", "#D5FF00", "#FF937E", "#6A826C",
    "#FF029D", "#FE8900", "#7A4782", "#7E2DD2",
    "#85A900", "#FF0056", "#A42400", "#00AE7E",
    "#683D3B", "#BDC6FF", "#263400", "#BDD393",
    "#00B917", "#9E008E", "#001544", "#C28C9F",
    "#FF74A3", "#01D0FF", "#004754", "#E56FFE",
    "#788231",
    # Previously "#0E4MIV" -- not valid hex (a CA1 -> MIV rename artefact);
    # restored to the palette's original value.
    "#0E4CA1",
    "#91D0CB", "#BE9970", "#968AE8", "#BB8800",
    "#43002C", "#DEFF74", "#00FFC6", "#FFE502",
    "#620E00", "#008F9C", "#98FF52", "#7544B1",
    "#B500FF", "#00FF78", "#FF6E41", "#005F39",
    "#6B6882", "#5FAD4E", "#A75740", "#A5FFD2",
    "#FFB167",
]
# Six-colour palette; plot_cell_tree picks one entry per gid when edges are
# not coloured by scalar value.
rainbow_colors = "#9400D3 #4B0082 #00FF00 #FFFF00 #FF7F00 #FF0000".split()

# Eleven-colour pastel palette, used e.g. for per-population node colours in
# plot_spatial_bin_graph.
raster_colors = (
    "#8dd3c7 #ffed6f #bebada #fb8072 #80b1d3 #fdb462 "
    "#b3de69 #fccde5 #d9d9d9 #bc80bd #ccebc5"
).split()
def hex2rgb(hexcode):
    """Convert a ``"#RRGGBB"`` colour code into a tuple of floats in [0, 1].

    :param hexcode: colour code as ``str`` or ``bytes``, including the leading
        ``#`` character
    :return: tuple with one float per byte of the hex code (``byte / 255``)
    :raises ValueError: if the code contains non-hexadecimal characters
    """
    if isinstance(hexcode, (bytes, bytearray)):
        # The old bytes branch relied on Python 2's ``.decode("hex")`` and
        # ``ord`` and could not work on Python 3; decode to text instead and
        # share the str code path.
        hexcode = hexcode.decode("ascii")
    return tuple(b / 255.0 for b in bytes.fromhex(hexcode[1:]))
def show_figure():
    """Display the open matplotlib figures by delegating to ``plt.show``."""
    plt.show()


def close_figure(fig):
    """Close the given matplotlib figure by delegating to ``plt.close``."""
    plt.close(fig)
def plot_graph(
    x, y, z, start_idx, end_idx, edge_scalars=None, edge_color=None, **kwargs
):
    """
    Draw graph edges as 3-D dashed vectors using Mayavi.

    Parameters
    ----------
    x, y, z : ndarray
        Coordinates of the graph points.
    start_idx, end_idx : ndarray of int
        Indices into the coordinate arrays of each edge's start and end point.
    edge_scalars : ndarray, optional
        Per-edge values used to colour the edges; when given, the glyphs are
        coloured by scalar and a colour bar is added.
    edge_color : tuple, optional
        Uniform edge colour, forwarded as the ``color`` keyword.
    kwargs :
        Extra keyword arguments forwarded to ``mlab.quiver3d`` and to the
        ``mlab.points3d`` call that marks point 0.

    Returns
    -------
    The object returned by ``mlab.quiver3d``.
    """
    from mayavi import mlab

    if edge_color is not None:
        kwargs["color"] = edge_color

    x0, y0, z0 = x[start_idx], y[start_idx], z[start_idx]
    vectors = mlab.quiver3d(
        x0,
        y0,
        z0,
        x[end_idx] - x0,
        y[end_idx] - y0,
        z[end_idx] - z0,
        scalars=edge_scalars,
        scale_factor=1,
        mode="2ddash",
        **kwargs,
    )
    # Mark the first point (index 0) with a cone glyph.
    mlab.points3d(x[0], y[0], z[0], mode="cone", scale_factor=3, **kwargs)

    if edge_scalars is None:
        return vectors

    vectors.glyph.color_mode = "color_by_scalar"
    colorbar = mlab.colorbar(vectors, label_fmt="%.1f")
    colorbar.label_text_property.font_size = 14
    return vectors
def clean_axes(axes, left=True, right=False):
    """
    Hide the top spine (and optionally the left/right spines) of pyplot axes.

    :param axes: a single pyplot Axes, or a list, tuple or numpy array (of any
        shape) of Axes
    :param left: keep the left spine visible if True, hide it if False
    :param right: keep the right spine visible if True, hide it if False
    """
    if isinstance(axes, np.ndarray):
        axes = axes.flatten()
    elif not isinstance(axes, (list, tuple)):
        axes = [axes]
    for axis in axes:
        axis.tick_params(direction="out")
        axis.spines["top"].set_visible(False)
        if not right:
            axis.spines["right"].set_visible(False)
        if not left:
            axis.spines["left"].set_visible(False)
        # Keep ticks only on the bottom and left sides.
        axis.get_xaxis().tick_bottom()
        axis.get_yaxis().tick_left()
def plot_spatial_bin_graph(graph_dict, **kwargs):
    """
    Draw a hive plot of a spatial-bin connectivity graph and save it as
    ``<label>.<figFormat>``.

    :param graph_dict: dict with keys ``"label"`` (output file stem),
        ``"U graph"`` (graph whose nodes are 2-tuples whose first element is a
        population name, as unpacked below), ``"destination"`` (destination
        population name) and ``"sources"`` (source population names)
    :param kwargs: figure options overriding ``default_fig_options``
    """
    import hiveplot as hv

    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    label = graph_dict["label"]
    GU = graph_dict["U graph"]
    destination = graph_dict["destination"]
    sources = graph_dict["sources"]

    graph_nodes = list(GU.nodes())
    graph_edges = list(GU.edges(data=True))

    # Nodes grouped by population (first tuple element), sorted per group.
    snodes = {
        group: sorted((s, d) for s, d in graph_nodes if s == group)
        for group in [destination, *sources]
    }
    # Edges grouped by the population of their second endpoint.
    edges = {
        source: [(u, v, d) for u, v, d in graph_edges if v[0] == source]
        for source in sources
    }

    # Cycle the palettes so any number of source populations is supported
    # (plain indexing raised IndexError past the palette length).
    nodes_cmap = {destination: "blue"}
    edges_cmap = {}
    for i, source in enumerate(sources):
        nodes_cmap[source] = raster_colors[i % len(raster_colors)]
        edges_cmap[source] = dflt_colors[i % len(dflt_colors)]

    hvpl = hv.HivePlot(snodes, edges, nodes_cmap, edges_cmap)
    hvpl.draw()
    plt.savefig(f"{label}.{fig_options.figFormat}")
def plot_coordinates(
    coords_path,
    population,
    namespace,
    index=0,
    graph_type="scatter",
    bin_size=0.01,
    xyz=False,
    **kwargs,
):
    """
    Plot the 2-D distribution of soma coordinates for one population.

    :param coords_path: path to the cell attribute file holding coordinates
    :param population: population name
    :param namespace: attribute namespace containing the coordinates
    :param index: element of each coordinate attribute array to plot
    :param graph_type: ``"scatter"`` or ``"histogram2d"``
    :param bin_size: bin width (coordinate units) for ``"histogram2d"``
    :param xyz: plot X/Y coordinates instead of U/V coordinates
    :param kwargs: figure options overriding ``default_fig_options``
        (``colormap`` selects the histogram colour map)
    :return: the matplotlib Axes containing the plot
    :raises ValueError: if *graph_type* is unknown or no coordinates are found
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    soma_coords = read_cell_attributes(
        coords_path, population, namespace=namespace
    )

    fig = plt.figure(1, figsize=plt.figaspect(1.0) * 2.0)
    ax = plt.gca()

    if xyz:
        u_key, v_key = "X Coordinate", "Y Coordinate"
    else:
        u_key, v_key = "U Coordinate", "V Coordinate"
    coord_U = {}
    coord_V = {}
    for k, v in soma_coords:
        coord_U[k] = v[u_key][index]
        coord_V[k] = v[v_key][index]
    if not coord_U:
        raise ValueError(
            f"No coordinates found for population {population} "
            f"in namespace {namespace}"
        )

    gids = sorted(coord_U.keys())
    coord_U_array = np.asarray([coord_U[k] for k in gids])
    coord_V_array = np.asarray([coord_V[k] for k in gids])

    x_min = np.min(coord_U_array)
    x_max = np.max(coord_U_array)
    y_min = np.min(coord_V_array)
    y_max = np.max(coord_V_array)

    if graph_type == "scatter":
        ax.scatter(coord_U_array, coord_V_array, alpha=0.1, linewidth=0)
        ax.axis([x_min, x_max, y_min, y_max])
    elif graph_type == "histogram2d":
        # contourf needs at least a 2x2 grid, so never use fewer than two
        # bins per axis (int() truncation gave 0 bins for small extents).
        nbins_x = max(2, int((x_max - x_min) / bin_size))
        nbins_y = max(2, int((y_max - y_min) / bin_size))
        H, xedges, yedges = np.histogram2d(
            coord_U_array, coord_V_array, bins=[nbins_x, nbins_y]
        )
        # Draw at the true bin centres; the actual bin width differs from
        # bin_size because of the integer bin count.
        x_centers = 0.5 * (xedges[:-1] + xedges[1:])
        y_centers = 0.5 * (yedges[:-1] + yedges[1:])
        X, Y = np.meshgrid(x_centers, y_centers)
        # Levels span the whole histogram that is drawn (H.T below).
        levels = MaxNLocator(nbins=25).tick_values(H.min(), H.max())
        cmap = plt.get_cmap(fig_options.colormap)
        p = ax.contourf(X, Y, H.T, levels=levels, cmap=cmap)
        fig.colorbar(p, ax=ax, shrink=0.5, aspect=20)
    else:
        raise ValueError(f"Unknown graph type {graph_type}")

    if xyz:
        ax.set_xlabel("X coordinate (um)", fontsize=fig_options.fontSize)
        ax.set_ylabel("Y coordinate (um)", fontsize=fig_options.fontSize)
    else:
        ax.set_xlabel(
            "U coordinate (septal - temporal)", fontsize=fig_options.fontSize
        )
        ax.set_ylabel(
            "V coordinate (supra - infrapyramidal)",
            fontsize=fig_options.fontSize,
        )
    ax.set_title(
        f"Coordinate distribution for population: {population}",
        fontsize=fig_options.fontSize,
    )

    if fig_options.saveFig:
        if isinstance(fig_options.saveFig, str):
            filename = fig_options.saveFig
        else:
            filename = f"{population} Coordinates.{fig_options.figFormat}"
        plt.savefig(filename)

    if fig_options.showFig:
        show_figure()

    return ax
# !needs refactoring
def plot_coords_in_volume(
    populations,
    coords_path,
    coords_namespace,
    config,
    scale=25.0,
    subpopulation=-1,
    subvol=False,
    verbose=False,
    mayavi=False,
    config_prefix="",
):
    """
    Plot soma coordinates of several populations inside the network volume.

    :param populations: population names to plot
    :param coords_path: path to the cell attribute file with X/Y/Z coordinates
    :param coords_namespace: attribute namespace containing the coordinates
    :param config: model configuration passed to ``Env``
    :param scale: Mayavi glyph scale factor / matplotlib marker size
    :param subpopulation: if > -1, randomly sample at most this many cells
        (without replacement) per population
    :param subvol: draw only the sub-volume spanned by the layer extents of the
        layers that contain cells of the plotted populations
    :param verbose: unused; kept for backward compatibility
    :param mayavi: render with Mayavi instead of matplotlib
    :param config_prefix: configuration prefix passed to ``Env``
    :return: the matplotlib Figure, or the current Mayavi figure when
        *mayavi* is True
    """
    from miv_simulator.geometry.geometry import get_total_extents
    from miv_simulator.volume import make_network_volume

    env = Env(config=config, config_prefix=config_prefix)
    rotate = env.geometry["Parametric Surface"]["Rotation"]
    layer_extents = env.geometry["Parametric Surface"]["Layer Extents"]
    (extent_u, extent_v, extent_l) = get_total_extents(layer_extents)

    logger.info("Reading coordinates...")

    pop_min_extent = None
    pop_max_extent = None
    xcoords = []
    ycoords = []
    zcoords = []
    colors = []
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("Dark2")
    cmap_range = np.linspace(0, 1, num=len(populations))

    for pop_id, population in enumerate(populations):
        coords = read_cell_attributes(
            coords_path, population, namespace=coords_namespace
        )
        count = 0
        cxcoords = []
        cycoords = []
        czcoords = []
        for _, v in coords:
            count += 1
            cxcoords.append(v["X Coordinate"][0])
            cycoords.append(v["Y Coordinate"][0])
            czcoords.append(v["Z Coordinate"][0])

        if subpopulation > -1 and count > subpopulation:
            ridxs = np.random.choice(
                np.arange(len(cxcoords)), replace=False, size=subpopulation
            )
            cxcoords = list(np.asarray(cxcoords)[ridxs])
            cycoords = list(np.asarray(cycoords)[ridxs])
            czcoords = list(np.asarray(czcoords)[ridxs])

        colors += [cmap(cmap_range[pop_id]) for _ in range(len(cxcoords))]
        xcoords += cxcoords
        ycoords += cycoords
        zcoords += czcoords
        logger.info(f"Read {count} coordinates...")

        # Accumulate the bounding extents of every layer this population
        # occupies (used when subvol is requested).
        pop_distribution = env.geometry["Cell Distribution"][population]
        for layer in pop_distribution:
            if pop_distribution[layer] <= 0:
                continue
            layer_min = np.asarray(layer_extents[layer][0])
            layer_max = np.asarray(layer_extents[layer][1])
            if pop_min_extent is None:
                pop_min_extent = layer_min
            else:
                pop_min_extent = np.minimum(pop_min_extent, layer_min)
            if pop_max_extent is None:
                pop_max_extent = layer_max
            else:
                # Previously np.maximum(pop_min_extent, ...), which mixed up
                # the lower bound into the upper bound.
                pop_max_extent = np.maximum(pop_max_extent, layer_max)

    pts = np.concatenate(
        (
            np.asarray(xcoords).reshape(-1, 1),
            np.asarray(ycoords).reshape(-1, 1),
            np.asarray(zcoords).reshape(-1, 1),
        ),
        axis=1,
    )

    logger.info("Plotting coordinates...")
    if mayavi:
        from mayavi import mlab

        mlab.points3d(*pts.T, color=(1, 1, 0), scale_factor=scale)
        # Previously `fig` was never bound on this path, so `return fig`
        # raised NameError.
        fig = mlab.gcf()
        ax = None
    else:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(*pts.T, c=colors, s=int(scale))

    logger.info("Constructing volume...")
    if subvol:
        volume = make_network_volume(
            (pop_min_extent[0], pop_max_extent[0]),
            (pop_min_extent[1], pop_max_extent[1]),
            (pop_min_extent[2], pop_max_extent[2]),
            resolution=[3, 3, 3],
            rotate=rotate,
        )
    else:
        volume = make_network_volume(
            (extent_u[0], extent_u[1]),
            (extent_v[0], extent_v[1]),
            (extent_l[0], extent_l[1]),
            resolution=[3, 3, 3],
            rotate=rotate,
        )

    logger.info("Plotting volume...")
    if mayavi:
        surface_color = (0, 0.4, 0) if subvol else (0, 1, 0)
        volume.mplot_surface(color=surface_color, opacity=0.33)
        mlab.show()
    else:
        volume.mplot_surface(color="k", alpha=0.33, figax=[fig, ax])
        ax.view_init(-90, 0)
        plt.show()

    return fig
def plot_cell_tree(
    population,
    gid,
    tree_dict,
    line_width=1.0,
    sample=0.05,
    color_edge_scalars=True,
    mst=False,
    conn_loc=True,
    mayavi=False,
    **kwargs,
):
    """
    Plot the morphology of a single cell as a 3-D point/edge graph.

    :param population: population name (used in output file names)
    :param gid: cell gid (used in output file names and to choose an edge
        colour from ``rainbow_colors``)
    :param tree_dict: tree attributes with keys ``"x"``, ``"y"``, ``"z"``,
        ``"layer"`` and ``"section_topology"`` (itself holding ``"nodes"``,
        ``"src"``, ``"dst"`` and ``"loc"``)
    :param line_width: edge line width (Mayavi only)
    :param sample: unused; kept for backward compatibility
    :param color_edge_scalars: Mayavi only -- colour edges by the z coordinate
        of their start point instead of a single per-gid colour
    :param mst: draw only the minimum spanning tree of the point graph
    :param conn_loc: Mayavi only -- mark the section connection points
    :param mayavi: render with Mayavi instead of matplotlib
    :param kwargs: figure options overriding ``default_fig_options``
    :return: the Mayavi or matplotlib figure
    """
    import networkx as nx

    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    x = np.asarray(tree_dict["x"]).reshape(-1)
    y = np.asarray(tree_dict["y"]).reshape(-1)
    z = np.asarray(tree_dict["z"]).reshape(-1)
    layer = tree_dict["layer"]
    topology = tree_dict["section_topology"]
    secnodes = topology["nodes"]

    # Edges between consecutive points of each section.
    edges = []
    for nodes in secnodes.values():
        for i in range(1, len(nodes)):
            edges.append((nodes[i - 1], nodes[i]))

    # One edge per section connection: from node `loc` of section `src` to the
    # first node of section `dst`; the source point is also recorded so it
    # can be marked.
    loc_x = []
    loc_y = []
    loc_z = []
    for s, d, l in zip(topology["src"], topology["dst"], topology["loc"]):
        srcnode = secnodes[s][l]
        edges.append((srcnode, secnodes[d][0]))
        loc_x.append(x[srcnode])
        loc_y.append(y[srcnode])
        loc_z.append(z[srcnode])
    conn_loc_x = np.asarray(loc_x, dtype=np.float64)
    conn_loc_y = np.asarray(loc_y, dtype=np.float64)
    conn_loc_z = np.asarray(loc_z, dtype=np.float64)

    # Make a NetworkX graph out of the point and edge data, optionally
    # reduced to its minimum spanning tree.
    g = make_geometric_graph(x, y, z, edges)
    graph_edges = nx.minimum_spanning_tree(g).edges() if mst else g.edges()
    # reshape keeps a (0, 2) shape for a cell without edges.
    edge_array = np.asarray(list(graph_edges), dtype=np.int_).reshape(-1, 2)
    start_idx = edge_array[:, 0]
    end_idx = edge_array[:, 1]

    if color_edge_scalars:
        edge_scalars = z[start_idx]
        edge_color = None
    else:
        edge_scalars = None
        edge_color = hex2rgb(rainbow_colors[gid % len(rainbow_colors)])

    if mayavi:
        from mayavi import mlab

        mlab.figure(bgcolor=(0, 0, 0))
        fig = mlab.gcf()
        plot_graph(
            x,
            y,
            z,
            start_idx,
            end_idx,
            edge_scalars=edge_scalars,
            edge_color=edge_color,
            opacity=0.8,
            colormap="summer",
            line_width=line_width,
            figure=fig,
        )
        if conn_loc:
            mlab.points3d(
                conn_loc_x,
                conn_loc_y,
                conn_loc_z,
                figure=fig,
                mode="2dcross",
                colormap="copper",
                scale_factor=10,
            )
        fig.scene.x_plus_view()
        if fig_options.saveFig:
            mlab.savefig(
                f"{population}_{gid}_cell_tree.x3d",
                figure=fig,
                magnification=10,
            )
        if fig_options.showFig:
            mlab.show()
    else:
        fig = plt.figure(figsize=fig_options.figSize)
        ax = fig.add_subplot(projection="3d")
        sct = ax.scatter(x, y, zs=z, c=layer, alpha=0.7)
        # Legend with one entry per unique layer colour in the scatter.
        layer_legend = ax.legend(
            *sct.legend_elements(), loc="upper right", title="Layer"
        )
        ax.add_artist(layer_legend)
        # Draw the same edge set as the Mayavi path, so `mst` is respected.
        for i, j in zip(start_idx, end_idx):
            ax.plot(
                (x[i], x[j]),
                (y[i], y[j]),
                (z[i], z[j]),
                c="black",
                alpha=0.5,
            )
        ax.view_init(30)
        # Previously `ax.set_axis_off` (no call) -- a no-op.
        ax.set_axis_off()
        if fig_options.saveFig:
            if isinstance(fig_options.saveFig, str):
                filename = fig_options.saveFig
            else:
                filename = (
                    f"{population}_{gid}_cell_tree.{fig_options.figFormat}"
                )
            plt.savefig(filename)
            logger.info(f"Save figure: {filename}")
        if fig_options.showFig:
            show_figure()

    return fig
## Plot spike raster
def _flatten_spike_arrays(chunks):
    """Concatenate a sequence of per-cell spike arrays into one 1-D array.

    Also accepts an already-flat array (each scalar element becomes one
    entry), so the result is 1-D either way.
    """
    if len(chunks) == 0:
        return np.asarray([])
    return np.concatenate([np.ravel(chunk) for chunk in chunks])


def plot_spike_raster(
    input_path,
    namespace_id,
    include=["eachPop"],
    time_range=None,
    time_variable="t",
    max_spikes=int(1e6),
    labels="legend",
    pop_rates=True,
    spike_hist=None,
    spike_hist_bin=5,
    include_artificial=True,
    marker=".",
    **kwargs,
):
    """
    Raster plot of network spike times. Returns the figure handle.

    input_path: file with spike data
    namespace_id: attribute namespace for spike events
    include: populations to plot; "eachPop" expands to every population in
        the file (default: ["eachPop"]). The list is not modified.
    time_range ([start, stop]): time range of spikes shown; if None shows all
        (default: None)
    time_variable: name of variable containing spike times (default: 't')
    max_spikes (int): maximum number of spikes plotted per population; larger
        populations are randomly subsampled without replacement
        (default: 1e6)
    labels ('legend'|'overlay'|'yticks'): how population labels are shown
        (default: 'legend')
    pop_rates (True|False): include mean firing rates in labels
        (default: True)
    spike_hist (None|'overlay'|'subplot'): draw a spike histogram
        (spikes/bin) over the last raster axis, or in its own subplot;
        True is treated as 'overlay' (default: None)
    spike_hist_bin (int): size of histogram bin in ms (default: 5)
    include_artificial: forwarded to spikedata.read_spike_events
    marker (char): marker for each spike (default: '.')
    kwargs: figure options overriding default_fig_options
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)
    mpl.rcParams["font.size"] = fig_options.fontSize

    (population_ranges, _) = read_population_ranges(input_path)
    population_names = read_population_names(input_path)
    pop_start_inds = {k: population_ranges[k][0] for k in population_names}
    pop_num_cells = {k: population_ranges[k][1] for k in population_names}

    # Work on a copy so neither the caller's list nor the default is mutated.
    include = list(include)
    if "eachPop" in include:
        include.remove("eachPop")
        include.extend(population_names)
    # Sort according to population start index.
    include.sort(key=lambda pop: pop_start_inds[pop])

    spkdata = spikedata.read_spike_events(
        input_path,
        include,
        namespace_id,
        include_artificial=include_artificial,
        spike_train_attr_name=time_variable,
        time_range=time_range,
    )
    spkpoplst = spkdata["spkpoplst"]
    spkindlst = spkdata["spkindlst"]
    spktlst = spkdata["spktlst"]
    num_cell_spks = spkdata["num_cell_spks"]
    pop_active_cells = spkdata["pop_active_cells"]
    time_range = [spkdata["tmin"], spkdata["tmax"]]

    fraction_active = {
        pop_name: float(len(pop_active_cells.get(pop_name, ())))
        / float(pop_num_cells[pop_name])
        for pop_name in include
    }

    # Spike histogram over all plotted populations, if requested.
    if spike_hist is True:
        spike_hist = "overlay"
    if spike_hist and spike_hist not in ("overlay", "subplot"):
        raise ValueError(
            f"plot_spike_raster: unknown spike_hist mode {spike_hist!r}"
        )
    sphist_x = None
    sphist_y = None
    if spike_hist:
        all_spkts = _flatten_spike_arrays(
            [_flatten_spike_arrays(lst) for lst in spktlst]
        )
        hist_bins = np.arange(time_range[0], time_range[1], spike_hist_bin)
        # PCHIP interpolation below needs at least two histogram bins.
        if len(all_spkts) > 0 and len(hist_bins) > 2:
            sphist_y, bin_edges = np.histogram(all_spkts, bins=hist_bins)
            sphist_x = bin_edges[:-1] + (spike_hist_bin / 2)
        else:
            spike_hist = None

    # Mean firing rate of the active cells of each population.
    avg_rates = {}
    tsecs = (time_range[1] - time_range[0]) / 1e3
    for pop_name in spkpoplst:
        pop_num = len(pop_active_cells[pop_name])
        if pop_num > 0:
            if num_cell_spks[pop_name] == 0:
                avg_rates[pop_name] = 0
            else:
                avg_rates[pop_name] = (
                    num_cell_spks[pop_name] / pop_num
                ) / tsecs

    pop_colors = {
        pop_name: dflt_colors[ipop % len(dflt_colors)]
        for ipop, pop_name in enumerate(spkpoplst)
    }
    pop_spk_dict = dict(zip(spkpoplst, zip(spkindlst, spktlst)))

    # One raster axis per population, plus a taller histogram axis in
    # "subplot" mode ("overlay" mode reuses the last raster axis).
    n_pops = max(len(spkpoplst), 1)
    if spike_hist == "subplot":
        n_subplots = n_pops + 1
        fig, axes = plt.subplots(
            nrows=n_subplots,
            sharex=True,
            figsize=fig_options.figSize,
            gridspec_kw={"height_ratios": [1] * n_pops + [2]},
        )
    else:
        n_subplots = n_pops
        fig, axes = plt.subplots(
            nrows=n_subplots, sharex=True, figsize=fig_options.figSize
        )
    fig.suptitle("Spike Raster", fontsize=fig_options.fontSize)
    if n_subplots == 1:
        axes = [axes]

    sctplots = []
    for i, pop_name in enumerate(spkpoplst):
        pop_spkinds, pop_spkts = pop_spk_dict.get(pop_name, ([], []))
        pop_spkinds = _flatten_spike_arrays(pop_spkinds)
        pop_spkts = _flatten_spike_arrays(pop_spkts)
        if max_spikes is not None and int(max_spikes) < len(pop_spkinds):
            logger.info(
                f"Loading only randomly sampled {max_spikes} out of {len(pop_spkts)} spikes for population {pop_name}"
            )
            sample_inds = np.random.choice(
                len(pop_spkinds), size=int(max_spikes), replace=False
            )
            pop_spkts = pop_spkts[sample_inds]
            pop_spkinds = pop_spkinds[sample_inds]

        sct = None
        if len(pop_spkinds) > 0:
            sct = axes[i].scatter(
                pop_spkts,
                pop_spkinds,
                s=1,
                linewidths=fig_options.lw,
                marker=marker,
                c=pop_colors[pop_name],
                alpha=0.5,
                label=pop_name,
            )
        for side in ("top", "bottom", "left", "right"):
            axes[i].spines[side].set_visible(False)
        sctplots.append(sct)

        N = pop_num_cells[pop_name]
        S = pop_start_inds[pop_name]
        axes[i].set_ylim(S, S + N - 1)

    lgd_info = [
        (
            100.0 * fraction_active.get(pop_name, 0.0),
            avg_rates.get(pop_name, 0.0),
        )
        for pop_name in spkpoplst
    ]

    # Put each raster's y tick label at the middle of its active gid range.
    for pop_name, a in zip(spkpoplst, fig.axes):
        active = pop_active_cells.get(pop_name, ())
        if len(active) > 0:
            loc = 0.5 * (min(active) + max(active))
            yaxis = a.get_yaxis()
            yaxis.set_ticks([loc])
            yaxis.set_ticklabels([pop_name])
            yaxis.set_tick_params(length=0)
        a.get_xaxis().set_tick_params(length=0)

    if spike_hist:
        # Smooth the histogram with a monotone cubic interpolant at
        # roughly one sample per ms.
        pch = interpolate.PchipInterpolator(sphist_x, sphist_y)
        res_npts = max(int(sphist_x.max() - sphist_x.min()), 2)
        sphist_x_res = np.linspace(
            sphist_x.min(), sphist_x.max(), res_npts, endpoint=True
        )
        sphist_y_res = pch(sphist_x_res)
        if spike_hist == "overlay":
            ax2 = axes[-1].twinx()
            ax2.plot(sphist_x_res, sphist_y_res, linewidth=0.5)
            ax2.set_ylabel(
                "Spike count", fontsize=fig_options.fontSize
            )  # add yaxis label in opposite side
            ax2.set_xlim(time_range)
        else:
            ax2 = axes[-1]
            ax2.bar(sphist_x_res, sphist_y_res, linewidth=1.0)
            ax2.set_xlabel("Time (ms)", fontsize=fig_options.fontSize)
            ax2.set_ylabel("Spikes", fontsize=fig_options.fontSize)
            ax2.set_xlim(time_range)

    if pop_rates:
        lgd_labels = [
            f"{pop_name} ({info[0]:.02f}% active; {info[1]:.3g} Hz)"
            for pop_name, info in zip(spkpoplst, lgd_info)
        ]
    else:
        lgd_labels = [
            f"{pop_name} ({info[0]:.02f}% active)"
            for pop_name, info in zip(spkpoplst, lgd_info)
        ]

    if labels == "legend":
        # Shrink axes by 15% to make room for the legend.
        if n_subplots > 1:
            for ax in axes:
                box = ax.get_position()
                ax.set_position([box.x0, box.y0, box.width * 0.85, box.height])
        # Populations without plotted spikes have no scatter handle.
        entries = [
            (handle, text)
            for handle, text in zip(sctplots, lgd_labels)
            if handle is not None
        ]
        if entries:
            handles, texts = zip(*entries)
            # fig.legend registers the legend with the figure itself; it must
            # not also be appended to fig.artists (that drew it twice).
            fig.legend(
                list(handles),
                list(texts),
                loc="center right",
                fontsize="small",
                scatterpoints=1,
                markerscale=5.0,
                bbox_to_anchor=(1.002, 0.5),
                bbox_transform=fig.transFigure,
            )
    elif labels == "overlay":
        # Each label already starts with its population name.
        for ax, lgd_label in zip(axes, lgd_labels):
            at = AnchoredText(
                lgd_label,
                loc="upper right",
                borderpad=0.01,
                prop=dict(size=fig_options.fontSize),
            )
            ax.add_artist(at)
    elif labels == "yticks":
        for pop_name, info, a in zip(spkpoplst, lgd_info, fig.axes):
            active = pop_active_cells.get(pop_name, ())
            if len(active) == 0:
                continue
            if pop_rates:
                label = f"{info[0]:.02f}%\n{info[1]:.2g} Hz"
            else:
                label = f"{info[0]:.02f}%\n"
            loc = 0.5 * (min(active) + max(active))
            a.set_yticks([loc, loc])
            a.set_yticklabels([pop_name, label])
            yticklabels = a.get_yticklabels()
            # Shift both labels 80 points to the left; the statistics label
            # is left-aligned so it reads next to the population name.
            offset = mpl.transforms.ScaledTranslation(
                -80 / 72.0, 0 / 72.0, fig.dpi_scale_trans
            )
            yticklabels[0].set_transform(
                yticklabels[0].get_transform() + offset
            )
            yticklabels[1].set_ha("left")
            yticklabels[1].set_transform(
                yticklabels[1].get_transform() + offset
            )

    plt.subplots_adjust(wspace=0.2, hspace=0.2)

    # save figure
    if fig_options.saveFig:
        if isinstance(fig_options.saveFig, str):
            filename = fig_options.saveFig
        else:
            filename = f"{namespace_id} raster.{fig_options.figFormat}"
        plt.savefig(filename)

    # show fig
    if fig_options.showFig:
        show_figure()

    return fig
def plot_spike_histogram(
    input_path: str,
    namespace_id,
    config_path: Optional[str] = None,
    include: List[str] = ["eachPop"],
    time_variable: str = "t",
    time_range: Optional[Tuple[int, int]] = None,
    pop_rates: bool = False,
    bin_size: int = 5,
    smooth: float = 0,
    quantity: str = "rate",
    include_artificial: bool = True,
    progress: bool = False,
    overlay: bool = True,
    graph_type: str = "bar",
    **kwargs,
):
    """
    Plots spike histogram. Returns figure handle.

    Parameters
    ----------
    input_path : str
        file with spike data
    namespace_id :
        attribute namespace for spike events
    config_path : Optional[str]
        model configuration; if given, its "Firing Rate Inference" analysis
        settings are passed to the spike density estimate
    include : List[str] (["eachPop", <population name>])
        List of data series to include. ("eachPop" expands to the name of each
        population.) The list is copied, never modified.
    time_variable : str
        Name of variable containing spike times (default: "t")
    time_range : Optional[Tuple[int,int]] ([start:stop])
        Time range of spikes shown. If None shows all. (default: None)
    pop_rates : bool
        Show the mean firing rate of the active cells in each title
    bin_size : int
        Size in ms of each bin (default: 5)
    smooth : float
        If non-zero, apply a Savitzky-Golay filter with this polynomial order
    quantity : str ("rate", "count", "active")
        Quantity of y axis (firing rate in Hz, spike count, or active cell
        count) (default: "rate")
    include_artificial : bool
        forwarded to spikedata.read_spike_events
    progress : bool
        forwarded to spikedata.spike_density_estimate
    overlay : bool
        Add a legend listing all populations (default: True)
    graph_type : str ("line", "bar")
        Type of graph to use (line graph or bar plot) (default: "bar")

    Returns
    -------
    The matplotlib figure, or None if *quantity* is invalid or no population
    has spikes.
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    # Validate the requested quantity before doing any I/O.
    y_labels = {
        "rate": "Mean cell firing rate (Hz)",
        "count": "Spike count",
        "active": "Active cell count",
    }
    if quantity not in y_labels:
        logger.error(f"Invalid quantity value {quantity}")
        return
    yaxisLabel = y_labels[quantity]

    baks_config = copy.copy(kwargs)
    if config_path is not None:
        env = Env(config=config_path)
        if env.analysis_config is not None:
            baks_config.update(env.analysis_config["Firing Rate Inference"])

    population_names = read_population_names(input_path)

    # Work on a copy: the original code mutated the shared default list (and
    # the caller's list) in place, so later calls saw stale population names.
    include = list(include)
    if "eachPop" in include:
        include.remove("eachPop")
        include.extend(population_names)
        include.reverse()

    spkdata = spikedata.read_spike_events(
        input_path,
        include,
        namespace_id,
        spike_train_attr_name=time_variable,
        time_range=time_range,
        include_artificial=include_artificial,
    )
    spkpoplst = spkdata["spkpoplst"]
    spkindlst = spkdata["spkindlst"]
    spktlst = spkdata["spktlst"]
    num_cell_spks = spkdata["num_cell_spks"]
    pop_active_cells = spkdata["pop_active_cells"]
    time_range = [spkdata["tmin"], spkdata["tmax"]]

    if len(spkpoplst) == 0:
        logger.warning("plot_spike_histogram: no spikes found to plot")
        return None

    avg_rates = {}
    if pop_rates:
        tsecs = (time_range[1] - time_range[0]) / 1e3
        for pop_name in spkpoplst:
            pop_num = len(pop_active_cells[pop_name])
            if pop_num > 0:
                if num_cell_spks[pop_name] == 0:
                    avg_rates[pop_name] = 0
                else:
                    avg_rates[pop_name] = (
                        num_cell_spks[pop_name] / pop_num
                    ) / tsecs

    # squeeze=False keeps a 2-D array even for a single population.
    fig, axes = plt.subplots(
        len(spkpoplst),
        1,
        figsize=fig_options.figSize,
        sharex=True,
        squeeze=False,
    )
    axes = axes[:, 0]

    time_bins = np.arange(time_range[0], time_range[1], bin_size)

    # Per population: per-bin sums over cells plus the number of cells that
    # contributed to each bin.
    hist_dict = {}
    if quantity == "rate":
        for subset, spkinds, spkts in zip(spkpoplst, spkindlst, spktlst):
            spkdict = spikedata.make_spike_dict(spkinds, spkts)
            sdf_dict = spikedata.spike_density_estimate(
                subset, spkdict, time_bins, progress=progress, **baks_config
            )
            bin_dict = defaultdict(lambda: {"rates": 0.0, "active": 0})
            for dct in sdf_dict.values():
                rate = dct["rate"]
                for ibin in range(0, len(time_bins)):
                    d = bin_dict[ibin]
                    d["rates"] += rate[ibin]
                    d["active"] += 1
            hist_dict[subset] = bin_dict
            logger.info(
                "Calculated spike rates for %i cells in population %s"
                % (len(sdf_dict), subset)
            )
    else:
        for subset, spkinds, spkts in zip(spkpoplst, spkindlst, spktlst):
            spkdict = spikedata.make_spike_dict(spkinds, spkts)
            count_bin_dict = spikedata.spike_bin_counts(spkdict, time_bins)
            bin_dict = defaultdict(lambda: {"counts": 0, "active": 0})
            for counts in count_bin_dict.values():
                for ibin in range(0, len(time_bins) - 1):
                    d = bin_dict[ibin]
                    d["counts"] += counts[ibin]
                    d["active"] += 1
            hist_dict[subset] = bin_dict
            logger.info(
                "Calculated spike counts for %i cells in population %s"
                % (len(count_bin_dict), subset)
            )

    del spkindlst, spktlst

    # Plot a separate panel for each population.
    hist_x = time_bins + (bin_size / 2)
    last_plot = len(spkpoplst) - 1
    for iplot, subset in enumerate(spkpoplst):
        bin_dict = hist_dict.pop(subset)
        if quantity == "rate":
            hist_y = np.asarray(
                [
                    (bin_dict[ibin]["rates"] / bin_dict[ibin]["active"])
                    if bin_dict[ibin]["active"] > 0
                    else 0.0
                    for ibin in range(0, len(time_bins))
                ]
            )
        elif quantity == "active":
            hist_y = np.asarray(
                [bin_dict[ibin]["active"] for ibin in range(0, len(time_bins))]
            )
        else:
            hist_y = np.asarray(
                [bin_dict[ibin]["counts"] for ibin in range(0, len(time_bins))]
            )
        del bin_dict

        color = dflt_colors[iplot % len(dflt_colors)]

        if pop_rates:
            label = str(subset) + " (%i active; %.3g Hz)" % (
                len(pop_active_cells[subset]),
                avg_rates.get(subset, 0.0),
            )
        else:
            label = str(subset) + f" ({len(pop_active_cells[subset])} active)"

        ax = axes[iplot]
        ax.set_title(label, fontsize=fig_options.fontSize)
        ax.tick_params(labelsize=fig_options.fontSize)
        if iplot < last_plot:
            ax.xaxis.set_visible(False)

        hsignal = hist_y
        if smooth:
            # savgol_filter needs an odd integer window (the old expression
            # produced a float) and a polyorder below the window length.
            window_length = 2 * (len(hist_y) // 16) + 1
            polyorder = int(smooth)
            if window_length > polyorder:
                hsignal = signal.savgol_filter(
                    hist_y,
                    window_length=window_length,
                    polyorder=polyorder,
                )
            else:
                logger.warning(
                    f"plot_spike_histogram: too few bins ({len(hist_y)}) to "
                    f"smooth with polynomial order {polyorder}; not smoothing"
                )

        if graph_type == "line":
            ax.plot(hist_x, hsignal, linewidth=fig_options.lw, color=color)
        elif graph_type == "bar":
            ax.bar(
                hist_x,
                hsignal,
                width=bin_size,
                color=color,
                edgecolor="black",
                alpha=0.85,
            )

        if iplot == 0:
            ax.set_ylabel(yaxisLabel, fontsize=fig_options.fontSize)
        if iplot == last_plot:
            ax.set_xlabel("Time (ms)", fontsize=fig_options.fontSize)
        else:
            # The string "off" is truthy in current matplotlib; use False.
            ax.tick_params(labelbottom=False)

        ax.set_xlim(time_range)

    fig.tight_layout()

    # Add legend
    if overlay:
        last_ax = axes[-1]
        for i, subset in enumerate(spkpoplst):
            last_ax.plot(
                0, 0, color=dflt_colors[i % len(dflt_colors)], label=str(subset)
            )
        last_ax.legend(
            fontsize=fig_options.fontSize,
            bbox_to_anchor=(1.04, 1),
            loc=2,
            borderaxespad=0.0,
        )
        maxLabelLen = min(10, max((len(str(l)) for l in include), default=0))
        fig.subplots_adjust(right=(0.9 - 0.012 * maxLabelLen))

    if fig_options.saveFig:
        if isinstance(fig_options.saveFig, str):
            filename = fig_options.saveFig
        else:
            filename = namespace_id + " " + f"histogram.{fig_options.figFormat}"
        plt.savefig(filename)

    if fig_options.showFig:
        show_figure()

    return fig
def _lfp_time_window(t, v, time_range):
    """Return (t, v) as arrays, restricted to ``time_range`` (inclusive) if given."""
    t = np.asarray(t)
    v = np.asarray(v)
    if time_range is None:
        return t, v
    mask = (t >= time_range[0]) & (t <= time_range[1])
    return t[mask], v[mask]


def _plot_lfp_row(
    fig,
    gs,
    row,
    psd_col,
    t,
    v,
    dt,
    title,
    trace_label,
    filtered_label,
    fig_options,
    compute_psd,
    frequency_range,
    window_size,
    overlap,
    bandpass_filter,
):
    """Draw one LFP trace (with optional band-pass overlay) and its PSD panel.

    The sampling rate is taken as ``1000 / dt`` (presumably dt in ms -- TODO
    confirm against the LFP configuration).
    """
    Fs = 1000.0 / dt

    if compute_psd:
        psd, freqs, peak_index = signal_psd(
            v,
            Fs=Fs,
            frequency_range=frequency_range,
            window_size=window_size,
            overlap=overlap,
        )

    filtered_v = None
    if bandpass_filter:
        filtered_v = apply_filter(
            v,
            butter_bandpass_filter(
                max(frequency_range[0], 1.0),
                frequency_range[1],
                Fs,
                order=2,
            ),
        )

    ax = fig.add_subplot(gs[row, 0])
    ax.set_title(title, fontsize=fig_options.fontSize)
    trace_kwargs = {} if trace_label is None else {"label": trace_label}
    ax.plot(t, v, linewidth=fig_options.lw, **trace_kwargs)
    ax.set_xlabel("Time (ms)", fontsize=fig_options.fontSize)
    ax.set_ylabel("Field Potential (mV)", fontsize=fig_options.fontSize)
    if filtered_v is not None:
        ax.plot(
            t,
            filtered_v,
            label=filtered_label,
            color="red",
            linewidth=fig_options.lw,
        )

    if compute_psd:
        ax = fig.add_subplot(gs[row, psd_col])
        ax.plot(freqs, psd, linewidth=fig_options.lw)
        ax.set_xlabel("Frequency (Hz)", fontsize=fig_options.fontSize)
        ax.set_ylabel(
            "Power Spectral Density (dB/Hz)", fontsize=fig_options.fontSize
        )
        ax.set_title(
            f"PSD (peak: {freqs[peak_index]:.3f} Hz)",
            fontsize=fig_options.fontSize,
        )


def plot_lfp(
    input_path,
    config_path=None,
    time_range=None,
    compute_psd=False,
    window_size=4096,
    frequency_range=(0, 400.0),
    overlap=0.9,
    bandpass_filter=False,
    dt=None,
    **kwargs,
):
    """
    Line plot of LFP state variable (default: v). Returns figure handle.

    input_path: file with LFP trace data -- a two-column (t, v) text file when
        config_path is None, otherwise an HDF5 file with one
        "Local Field Potential <label>" group per configured LFP
    config_path: path to model configuration file (default: None)
    time_range ([start:stop]): time range shown; if None shows all
        (default: None)
    compute_psd: add a power spectral density panel per trace
    window_size, overlap: forwarded to signal_psd
    frequency_range: PSD range and band-pass limits (lower limit >= 1 Hz for
        the filter)
    bandpass_filter: overlay a band-pass filtered trace
    dt: sampling interval; required when config_path is None, otherwise read
        from each LFP configuration entry
    kwargs: figure options overriding default_fig_options
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    env = None
    if config_path is not None:
        env = Env(config=config_path)
    elif dt is None:
        raise RuntimeError(
            "plot_lfp: dt must be provided when config_path is None"
        )

    nrows = 1
    if env is not None:
        # GridSpec rejects zero rows; an empty LFP config gives an empty row.
        nrows = max(1, len(env.LFP_config))
    ncols = 1
    psd_col = 1
    if compute_psd:
        ncols += 1

    gs = gridspec.GridSpec(
        nrows, ncols, width_ratios=[3, 1] if ncols > 1 else [1]
    )
    fig = plt.figure(figsize=fig_options.figSize)

    row_options = dict(
        fig_options=fig_options,
        compute_psd=compute_psd,
        frequency_range=frequency_range,
        window_size=window_size,
        overlap=overlap,
        bandpass_filter=bandpass_filter,
    )

    save_name = "MIV LFP"
    if env is None:
        lfp_array = np.loadtxt(
            input_path, dtype=np.dtype([("t", np.float32), ("v", np.float32)])
        )
        t, v = _lfp_time_window(lfp_array["t"], lfp_array["v"], time_range)
        _plot_lfp_row(
            fig,
            gs,
            0,
            psd_col,
            t,
            v,
            dt,
            "LFP",
            None,
            "Filtered LFP",
            **row_options,
        )
    else:
        import h5py

        # Open the file once and close it deterministically (it previously
        # leaked one open handle per LFP entry).
        with h5py.File(input_path, "r") as infile:
            for iplot, (lfp_label, lfp_config_dict) in enumerate(
                env.LFP_config.items()
            ):
                namespace_id = f"Local Field Potential {str(lfp_label)}"
                logger.info(f"plot_lfp: reading data for {namespace_id}...")
                t, v = _lfp_time_window(
                    infile[namespace_id]["t"],
                    infile[namespace_id]["v"],
                    time_range,
                )
                _plot_lfp_row(
                    fig,
                    gs,
                    iplot,
                    psd_col,
                    t,
                    v,
                    lfp_config_dict["dt"],
                    f"{namespace_id}",
                    lfp_label,
                    f"{lfp_label} (filtered)",
                    **row_options,
                )
                # The default output name follows the last LFP plotted.
                save_name = namespace_id

    # save figure
    if fig_options.saveFig:
        if isinstance(fig_options.saveFig, str):
            filename = fig_options.saveFig
        else:
            filename = f"{save_name}.{fig_options.figFormat}"
        plt.savefig(filename)

    # show fig
    if fig_options.showFig:
        show_figure()

    return fig
def _lfp_time_window(t, v, time_range):
    """Return LFP samples ``(t, v)`` as NumPy arrays, optionally restricted in time.

    :param t: sample times (ms); any array-like, including an h5py dataset
    :param v: sample values, same length as ``t``
    :param time_range: ``None`` to keep every sample, otherwise ``(start, stop)``;
        samples with ``start <= t <= stop`` (both ends inclusive) are kept
    :return: ``(t, v)`` as NumPy arrays
    """
    t = np.asarray(t)
    v = np.asarray(v)
    if time_range is None:
        return t, v
    in_window = (t >= time_range[0]) & (t <= time_range[1])
    return t[in_window], v[in_window]


def _lfp_spectrogram_band(v, Fs, window_size, overlap, frequency_range):
    """Compute the power spectrogram of ``v`` and keep only frequencies in range.

    :return: ``(freqs, times, sxx)`` as produced by ``signal_power_spectrogram``,
        keeping only the rows whose frequency satisfies
        ``frequency_range[0] <= f <= frequency_range[1]``
    """
    freqs, times, Sxx = signal_power_spectrogram(v, Fs, window_size, overlap)
    in_band = (freqs >= frequency_range[0]) & (freqs <= frequency_range[1])
    return freqs[in_band], times, Sxx[in_band, :]


def _draw_lfp_spectrogram(fig, ax, times, freqs, sxx, frequency_range, title, fig_options):
    """Draw one spectrogram panel (pcolormesh, labels, colorbar) on ``ax``."""
    ax.set_ylim(*frequency_range)
    ax.set_title(title, fontsize=fig_options.fontSize)
    pcm = ax.pcolormesh(times, freqs, sxx, cmap="jet")
    ax.set_xlabel("Time (s)", fontsize=fig_options.fontSize)
    ax.set_ylabel("Frequency (Hz)", fontsize=fig_options.fontSize)
    ax.tick_params(axis="both", labelsize=fig_options.fontSize)
    fig.colorbar(pcm, ax=ax)


def plot_lfp_spectrogram(
    input_path,
    config_path=None,
    time_range=None,
    window_size=4096,
    overlap=0.9,
    frequency_range=(0, 400.0),
    dt=None,
    **kwargs,
):
    """
    Plot the LFP power spectrogram. Returns the figure handle.

    input_path: file with LFP trace data. When config_path is None it is read with
        np.loadtxt as two columns (t, v); otherwise it is opened with h5py and one
        "Local Field Potential <label>" group is read per entry of env.LFP_config.
    config_path: path to model configuration file (optional)
    time_range ([start:stop]): time range (ms) of samples used, inclusive;
        if None uses all samples (default: None)
    window_size, overlap: passed to signal_power_spectrogram
    frequency_range ((low, high)): frequencies (Hz) kept and displayed
        (default: (0, 400))
    dt: sampling interval (ms); required when config_path is None, otherwise
        taken from the "dt" entry of each LFP configuration
    kwargs: overrides for default_fig_options (saveFig, showFig, figSize, ...)
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)
    mpl.rcParams["font.size"] = fig_options.fontSize

    env = None
    if config_path is not None:
        env = Env(config=config_path)

    # One row per LFP configuration (a single row when no config is given).
    nrows = 1 if env is None else len(env.LFP_config)
    gs = gridspec.GridSpec(nrows, 1, width_ratios=[1])
    fig = plt.figure(figsize=fig_options.figSize)

    if env is None:
        if dt is None:
            raise RuntimeError(
                "plot_lfp_spectrogram: dt must be provided when config_path is None"
            )
        lfp_array = np.loadtxt(
            input_path, dtype=np.dtype([("t", np.float32), ("v", np.float32)])
        )
        _, v = _lfp_time_window(lfp_array["t"], lfp_array["v"], time_range)
        Fs = int(1000.0 / dt)
        freqs, times, sxx = _lfp_spectrogram_band(
            v, Fs, window_size, overlap, frequency_range
        )
        ax = plt.subplot(gs[0, 0])
        _draw_lfp_spectrogram(
            fig, ax, times, freqs, sxx, frequency_range, "LFP Spectrogram", fig_options
        )
        # save figure
        if fig_options.saveFig:
            if isinstance(fig_options.saveFig, str):
                filename = fig_options.saveFig
            else:
                # There is no LFP namespace without a config; use a fixed name.
                filename = f"MIV LFP Spectrogram.{fig_options.figFormat}"
            plt.savefig(filename)
        # show fig
        if fig_options.showFig:
            show_figure()
    else:
        import h5py

        for iplot, (lfp_label, lfp_config_dict) in enumerate(
            env.LFP_config.items()
        ):
            namespace_id = f"Local Field Potential {str(lfp_label)}"
            logger.info(f"plot_lfp_spectrogram: reading data for {namespace_id}...")
            # Copy the samples into memory, then close the file.
            with h5py.File(input_path, "r") as infile:
                _, v = _lfp_time_window(
                    infile[namespace_id]["t"], infile[namespace_id]["v"], time_range
                )
            Fs = int(1000.0 / lfp_config_dict["dt"])
            freqs, times, sxx = _lfp_spectrogram_band(
                v, Fs, window_size, overlap, frequency_range
            )
            ax = plt.subplot(gs[iplot, 0])
            _draw_lfp_spectrogram(
                fig, ax, times, freqs, sxx, frequency_range, f"{namespace_id}", fig_options
            )
            # save figure
            if fig_options.saveFig:
                if isinstance(fig_options.saveFig, str):
                    filename = fig_options.saveFig
                else:
                    filename = namespace_id + f".{fig_options.figFormat}"
                plt.savefig(filename)
            # show fig
            if fig_options.showFig:
                show_figure()
    return fig
## Plot biophys cell tree
def _morph_graph_node_arrays(morph_graph):
    """Return ``(x, y, z, layer)`` node attribute arrays of ``morph_graph``.

    Arrays follow the graph's node iteration order. Coordinates are float32 and
    layers are int32.
    """
    nodes = morph_graph.nodes
    xcoords = np.asarray([x for (_, x) in nodes.data("x")], dtype=np.float32)
    ycoords = np.asarray([y for (_, y) in nodes.data("y")], dtype=np.float32)
    zcoords = np.asarray([z for (_, z) in nodes.data("z")], dtype=np.float32)
    layer = np.asarray(
        [lyr for (_, lyr) in nodes.data("layer")], dtype=np.int32
    )
    return xcoords, ycoords, zcoords, layer


def _cell_tree_synapse_locations(
    env, biophys_cell, synapse_filters, syn_source_threshold
):
    """Select the cell's synapses and compute their xyz positions per section.

    :return: ``(syn_xyz_sec_dict, syn_src_sec_dict, n_syns)``, where
        ``syn_xyz_sec_dict[sec_id]`` is an (n, 3) array of synapse positions,
        ``syn_src_sec_dict[sec_id]`` holds the matching source gids, and
        ``n_syns`` is the number of selected synapses.
    """
    syn_attrs = env.synapse_attributes
    synapse_filters = synapses.get_syn_filter_dict(
        env, synapse_filters, convert=True
    )
    syns_dict = syn_attrs.filter_synapses(biophys_cell.gid, **synapse_filters)

    if (syn_source_threshold is not None) and (syn_source_threshold > 0.0):
        # Keep only synapses whose source contributes at least the given
        # fraction of the largest per-source synapse count.
        syn_source_count = defaultdict(int)
        for syn in syns_dict.values():
            syn_source_count[syn.source.gid] += 1
        syn_source_max = max(syn_source_count.values(), default=0)
        logger.info("synapse source max count is %d" % (syn_source_max))
        syns_dict = {
            syn_id: syn
            for syn_id, syn in syns_dict.items()
            if syn_source_count[syn.source.gid] / syn_source_max
            >= syn_source_threshold
        }

    syn_sec_dict = defaultdict(list)
    for syn in syns_dict.values():
        syn_sec_dict[syn.syn_section].append(syn)

    syn_xyz_sec_dict = {}
    syn_src_sec_dict = {}
    for sec_id, sec_syns in syn_sec_dict.items():
        sec = biophys_cell.hoc_cell.sections[sec_id]
        syn_locs = [syn.syn_loc for syn in sec_syns]
        ip_x, ip_y, ip_z, _ = interplocs(sec)
        syn_xyz_sec_dict[sec_id] = np.column_stack(
            (ip_x(syn_locs), ip_y(syn_locs), ip_z(syn_locs))
        )
        syn_src_sec_dict[sec_id] = np.asarray(
            [syn.source.gid for syn in sec_syns]
        )
    return syn_xyz_sec_dict, syn_src_sec_dict, len(syns_dict)


def plot_biophys_cell_tree(
    env,
    biophys_cell,
    node_filters={"swc_types": ["apical", "basal"]},
    plot_synapses=False,
    synapse_filters=None,
    syn_source_threshold=0.0,
    line_width=8.0,
    plot_method="neuron",
    **kwargs,
):
    """
    Plot cell morphology and synapse locations. Returns the figure.

    env: Env providing synapse_attributes
    biophys_cell: cell to draw (uses .gid, .population_name and .hoc_cell)
    node_filters: passed to cells.make_morph_graph to select morphology nodes
    plot_synapses: currently not consulted; the "mayavi" and "matplotlib"
        methods always draw the synapses selected by synapse_filters
    synapse_filters: passed to synapses.get_syn_filter_dict
    syn_source_threshold: if > 0, keep only synapses whose source gid has at
        least this fraction of the maximum per-source synapse count
    line_width: edge line width (mayavi only)
    plot_method: "mayavi", "matplotlib", or anything else for NEURON PlotShape
    kwargs: overrides for default_fig_options ("colormap" and "bgcolor" are also
        read by the mayavi method)
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)
    morph_graph = cells.make_morph_graph(
        biophys_cell, node_filters=node_filters
    )
    gid = biophys_cell.gid
    population = biophys_cell.population_name

    syn_xyz_sec_dict, syn_src_sec_dict, n_syns = _cell_tree_synapse_locations(
        env, biophys_cell, synapse_filters, syn_source_threshold
    )

    fig = None
    if plot_method == "mayavi":
        from mayavi import mlab

        colormap = kwargs.get("colormap", "coolwarm")
        mlab.figure(bgcolor=kwargs.get("bgcolor", (0, 0, 0)))
        xcoords, ycoords, zcoords, _ = _morph_graph_node_arrays(morph_graph)
        # Edge endpoints index the coordinate arrays directly; this assumes
        # nodes are labelled 0..N-1 in iteration order -- TODO confirm.
        edge_list = list(morph_graph.edges())
        start_idx = np.asarray([u for u, _ in edge_list], dtype=int)
        end_idx = np.asarray([w for _, w in edge_list], dtype=int)
        logger.info(f"plotting tree {gid}")
        # Plot morphology graph with Mayavi
        plot_graph(
            xcoords,
            ycoords,
            zcoords,
            start_idx,
            end_idx,
            edge_color=(1, 1, 1),
            opacity=0.6,
            line_width=line_width,
        )
        logger.info(f"plotting {n_syns} synapses")
        for sec_id, syn_xyz in syn_xyz_sec_dict.items():
            syn_sources = syn_src_sec_dict[sec_id]
            if None in syn_sources:
                mlab.points3d(
                    syn_xyz[:, 0],
                    syn_xyz[:, 1],
                    syn_xyz[:, 2],
                    scale_mode="vector",
                    colormap=colormap,
                    scale_factor=10.0,
                    color=(1, 0, 0),
                )
            else:
                mlab.points3d(
                    syn_xyz[:, 0],
                    syn_xyz[:, 1],
                    syn_xyz[:, 2],
                    syn_sources,
                    scale_mode="vector",
                    colormap=colormap,
                    scale_factor=10.0,
                    color=(1, 0, 0),
                )
        mlab.gcf().scene.x_plus_view()
        if fig_options.showFig:
            mlab.show()
        fig = mlab.gcf()
    elif plot_method == "matplotlib":
        fig = plt.figure(figsize=fig_options.figSize)
        ax = fig.add_subplot(projection="3d")
        xcoords, ycoords, zcoords, layer = _morph_graph_node_arrays(morph_graph)
        sct = ax.scatter(
            xcoords,
            ycoords,
            zcoords,
            c=layer,
            alpha=0.7,
        )
        # produce a legend with the unique colors from the scatter
        legend_elements = sct.legend_elements()
        layer_legend = ax.legend(
            *legend_elements, loc="upper right", title="Layer"
        )
        ax.add_artist(layer_legend)
        # Same node-id-as-index assumption as above -- TODO confirm.
        for i, j in morph_graph.edges:
            e_x = (xcoords[i], xcoords[j])
            e_y = (ycoords[i], ycoords[j])
            e_z = (zcoords[i], zcoords[j])
            ax.plot(e_x, e_y, e_z, c="black", alpha=0.5)
        for sec_id, syn_xyz in syn_xyz_sec_dict.items():
            syn_sources = syn_src_sec_dict[sec_id]
            if None in syn_sources:
                ax.scatter(
                    syn_xyz[:, 0],
                    syn_xyz[:, 1],
                    syn_xyz[:, 2],
                    marker="^",
                    s=100,
                )
            else:
                ax.scatter(
                    syn_xyz[:, 0],
                    syn_xyz[:, 1],
                    syn_xyz[:, 2],
                    c=syn_sources,
                    marker="o",
                )
        ax.view_init(30)
        ax.set_axis_off()
        if fig_options.saveFig:
            if isinstance(fig_options.saveFig, str):
                filename = fig_options.saveFig
            else:
                filename = (
                    f"{population}_{gid}_cell_tree.{fig_options.figFormat}"
                )
            plt.savefig(filename)
        if fig_options.showFig:
            plt.show()
    else:
        sl = h.SectionList([sec for sec in biophys_cell.hoc_cell.all])
        for sec in sl:
            sec.v = 0
        h.topology()
        h.psection(list(sl)[0])
        # False tells h.PlotShape not to use NEURON's gui
        ps = h.PlotShape(sl, False)
        ps.plot(plt)
        fig = plt.gcf()
        if fig_options.showFig:
            plt.show()
    return fig
# =============================================================================
# Radially integrated 2D PSD: total power in each angular sector, normalized to sum to one
# =============================================================================
def get_RPSD(psd2D, dTheta=30, rMin=10, rMax=100):
    """Sum a 2D power spectrum into angular sectors and normalize the result.

    Pixels whose distance from the array centre is strictly between ``rMin``
    and ``rMax`` are assigned to sectors ``dTheta`` degrees wide. Each sector is
    centred on its label angle, and angles are measured counter-clockwise from
    the +x (column) axis as displayed. The power in each sector is summed and
    divided by the total over all sectors.

    :param psd2D: 2D array, indexed [row, col] (e.g. an ``fftshift``-ed periodogram)
    :param dTheta: sector width in degrees
    :param rMin: exclusive inner radius, in pixels
    :param rMax: exclusive outer radius, in pixels
    :return: ``(angF, psd1D)``: sector label angles ``0, dTheta, ...`` (degrees)
        and the normalized power per sector. When the total power is zero,
        ``psd1D`` is all zeros rather than NaN.
    """
    n_rows = psd2D.shape[0]
    n_cols = psd2D.shape[1]
    row_c = n_rows // 2
    col_c = n_cols // 2
    # Displaying the PSD as an image inverts the Y axis, so the row offset is
    # negated to obtain a conventional counter-clockwise angle.
    Y, X = np.ogrid[0:n_rows, 0:n_cols]
    dy = -(Y - row_c)
    dx = X - col_c
    theta = np.rad2deg(np.arctan2(dy, dx))
    # Shift by half a sector so each label angle is the sector centre, then
    # floor to an integer multiple of dTheta.
    theta = np.mod(theta + dTheta / 2 + 360, 360)
    theta = (dTheta * (theta // dTheta)).astype(int)
    # Pixels outside the annulus rMin < R < rMax get label -100, which matches
    # no sector label below, so they are excluded from the sums.
    R = np.hypot(dy, dx)
    in_annulus = np.logical_and(R > rMin, R < rMax)
    theta = np.where(in_annulus, theta, -100)
    # Sum all psd2D pixels carrying each sector label 0 <= theta < 360.
    angF = np.arange(0, 360, int(dTheta))
    psd1D = np.asarray(ndimage.sum(psd2D, theta, index=angF), dtype=float)
    # Normalize each sector by the total sector power.
    pwrTotal = np.sum(psd1D)
    if pwrTotal == 0:
        return angF, np.zeros_like(psd1D)
    return angF, psd1D / pwrTotal
def plot_2D_rate_map(
    x,
    y,
    rate_map,
    x0=None,
    y0=None,
    peak_rate=None,
    title=None,
    fft_vmax=10.0,
    density_bin_size=10.0,
    **kwargs,
):
    """
    Plot a 2D firing-rate map, its 2D periodogram and its angular power distribution.

    :param x: array of X positions (cm), passed to ``pcolor``
    :param y: array of Y positions (cm), passed to ``pcolor``
    :param rate_map: 2D array of firing rates (Hz)
    :param x0: optional X coordinates drawn in an extra point-density panel
    :param y0: optional Y coordinates paired with ``x0``
    :param peak_rate: float; upper colour limit of the rate map (default: max of rate_map)
    :param title: str; optional figure title
    :param fft_vmax: upper colour limit of the periodogram panel
    :param density_bin_size: currently unused
    :param kwargs: overrides for default_fig_options
    :return: the matplotlib figure
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)
    if peak_rate is None:
        peak_rate = np.max(rate_map)
    fig = plt.figure(constrained_layout=True, figsize=fig_options.figSize)
    gs = gridspec.GridSpec(2, 3, figure=fig, width_ratios=[2, 1, 1])
    x_min = np.min(x)
    x_max = np.max(x)
    y_min = np.min(y)
    y_max = np.max(y)
    has_points = x0 is not None and y0 is not None

    # Rate map
    ax = fig.add_subplot(gs[0, 0])
    pc = ax.pcolor(
        x, y, rate_map, vmin=0.0, vmax=peak_rate, cmap=fig_options.colormap
    )
    cbar = fig.colorbar(pc, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(
        "Firing Rate (Hz)",
        rotation=270.0,
        labelpad=20.0,
        fontsize=fig_options.fontSize,
    )
    ax.set_title("Rate Map")
    ax.set_aspect("equal")
    ax.set_xlabel("X Position (cm)", fontsize=fig_options.fontSize)
    ax.set_ylabel("Y Position (cm)", fontsize=fig_options.fontSize)
    ax.tick_params(labelsize=fig_options.fontSize)
    clean_axes(ax)

    # Optional point-density panel
    if has_points:
        ax = fig.add_subplot(gs[0, 1])
        ax.set_title("Point Density")
        plot_2D_point_density(np.column_stack((x0, y0)), ax=ax)

    # Periodogram: magnitude of the centred 2D FFT of the mean-subtracted map
    ax = fig.add_subplot(gs[1, 0])
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    psd2D = np.abs(
        np.fft.fftshift(
            np.fft.fft2(rate_map - np.mean(rate_map)) / rate_map.shape[0]
        )
    )
    im = ax.imshow(
        psd2D,
        vmax=fft_vmax,
        cmap=fig_options.colormap,
        extent=[x_min, x_max, y_min, y_max],
    )
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(
        "Power", rotation=270.0, labelpad=20.0, fontsize=fig_options.fontSize
    )
    ax.set_title("Rate Periodogram")
    ax.set_aspect("equal")

    # Angular power distribution of the periodogram. No colour data is passed
    # to scatter, so no cmap is given (matplotlib would ignore it and warn).
    angF, rpsd = get_RPSD(psd2D)
    if has_points:
        ax = fig.add_subplot(gs[1, 1])
    else:
        ax = fig.add_subplot(gs[:, 1])
    ax.scatter(angF, rpsd)
    ax.set_title("Radially Averaged Spectrogram")

    if title is not None:
        fig.suptitle(title, fontsize=fig_options.fontSize)
    if fig_options.saveFig is not None:
        save_figure(fig_options.saveFig, fig=fig, **fig_options())
    if fig_options.showFig:
        plt.show()
    return fig
def plot_2D_histogram(
    hist,
    x_edges,
    y_edges,
    norm=None,
    ylabel=None,
    xlabel=None,
    title=None,
    cbar_label=None,
    cbar=True,
    vmin=0.0,
    vmax=None,
    **kwargs,
):
    """
    Plot a 2D histogram as a pseudocolor mesh. Returns the figure.

    :param hist: ndarray indexed [x_bin, y_bin] (it is transposed for pcolormesh)
    :param x_edges: ndarray of x bin edges
    :param y_edges: ndarray of y bin edges
    :param norm: ndarray; optionally normalize hist by nonzero elements of norm array
    :param ylabel: str
    :param xlabel: str
    :param title: str
    :param cbar_label: str
    :param cbar: bool; draw a colorbar (and reserve space for it) only when True
    :param vmin: float; values below vmin are drawn transparent
    :param vmax: float; default is the maximum of the (normalized) histogram
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    # Work in floating point so that normalizing an integer-count histogram
    # does not truncate the quotients.
    H = np.array(hist, dtype=np.float64)
    if norm is not None:
        norm = np.asarray(norm)
        non_zero = np.where(norm > 0.0)
        H[non_zero] = np.divide(H[non_zero], norm[non_zero])
    if vmax is None:
        vmax = np.max(H)

    fig, axes = plt.subplots(figsize=fig_options.figSize)

    # Discrete colormap with one colour per 0.1-wide level in [vmin, vmax).
    # The last colour is dropped, so at least two levels are required.
    pcm_cmap = None
    pcm_boundaries = np.arange(vmin, vmax, 0.1)
    if len(pcm_boundaries) > 1:
        cmap_pls = plt.get_cmap(fig_options.colormap, len(pcm_boundaries))
        pcm_colors = list(cmap_pls(np.arange(len(pcm_boundaries))))
        pcm_cmap = mpl.colors.ListedColormap(pcm_colors[:-1], "")
        pcm_cmap.set_under(pcm_colors[0], alpha=0.0)
    pcm = axes.pcolormesh(
        x_edges, y_edges, H.T, vmin=vmin, vmax=vmax, cmap=pcm_cmap
    )
    axes.set_aspect("equal")
    axes.tick_params(labelsize=fig_options.fontSize)

    if cbar:
        divider = make_axes_locatable(axes)
        cax = divider.append_axes("right", size="2.5%", pad=0.1)
        cb = fig.colorbar(pcm, cax=cax)
        cb.ax.tick_params(labelsize=fig_options.fontSize)
        if cbar_label is not None:
            cb.set_label(
                cbar_label,
                rotation=270.0,
                labelpad=20.0,
                fontsize=fig_options.fontSize,
            )
    if xlabel is not None:
        axes.set_xlabel(xlabel, fontsize=fig_options.fontSize)
    if ylabel is not None:
        axes.set_ylabel(ylabel, fontsize=fig_options.fontSize)
    if title is not None:
        axes.set_title(title, fontsize=fig_options.fontSize)
    clean_axes(axes)
    if fig_options.saveFig is not None:
        save_figure(fig_options.saveFig, fig=fig, **fig_options())
    if fig_options.showFig:
        plt.show()
    return fig
def plot_2D_point_density(data, width=100, height=100, ax=None, inc=0.3):
    """
    Draw a 2D point set as a saturating "hot" density image.

    :param data: array with x in column 0 and y in column 1, one row per point
    :param width: the image has width + 1 columns; if None, the rounded x range
    :param height: the image has height + 1 rows; if None, the rounded y range
    :param ax: axes to draw on; a new figure and axes are created if None
    :param inc: intensity each point adds to its pixel; pixels are capped at 1.0
    :return: the axes drawn on
    """
    data = np.asarray(data)

    def points_image(data, height, width, inc=0.3):
        xlims = (data[:, 0].min(), data[:, 0].max())
        ylims = (data[:, 1].min(), data[:, 1].max())
        # A zero-width range would divide by zero below; widen it by 0.5 on
        # each side so the points land in the middle of the image.
        if xlims[1] == xlims[0]:
            xlims = (xlims[0] - 0.5, xlims[1] + 0.5)
        if ylims[1] == ylims[0]:
            ylims = (ylims[0] - 0.5, ylims[1] + 0.5)
        dxl = xlims[1] - xlims[0]
        dyl = ylims[1] - ylims[0]
        cols = np.rint(((data[:, 0] - xlims[0]) / dxl) * width).astype(int)
        # Row 0 is the top of the image, so y is flipped.
        rows = np.rint((1 - (data[:, 1] - ylims[0]) / dyl) * height).astype(int)
        img = np.zeros((height + 1, width + 1))
        # Unbuffered accumulation: repeated pixels each add inc, then cap at 1.
        np.add.at(img, (rows, cols), inc)
        np.minimum(img, 1.0, out=img)
        return xlims, ylims, img

    if width is None:
        width = int(round(data[:, 0].max() - data[:, 0].min()))
    if height is None:
        height = int(round(data[:, 1].max() - data[:, 1].min()))
    xlims, ylims, img = points_image(data, height=height, width=width, inc=inc)
    ax_extent = list(xlims) + list(ylims)
    if ax is None:
        fig, ax = plt.subplots()
    ax.imshow(
        img,
        vmin=0,
        vmax=1,
        cmap=plt.get_cmap("hot"),
        interpolation="hermite",
        aspect="auto",
        extent=ax_extent,
    )
    return ax
## Plot intracellular state trace
def plot_intracellular_state(
    input_path,
    namespace_ids,
    include=("eachPop",),
    time_range=None,
    time_variable="t",
    state_variable="v",
    max_units=1,
    gid_set=None,
    n_trials=1,
    labels="legend",
    lowpass_plot=None,
    reduce=False,
    distance=False,
    **kwargs,
):
    """
    Line plot of intracellular state variable (default: v). Returns the figure handle.

    input_path: file with state data
    namespace_ids: attribute namespaces
    include: populations to plot; 'eachPop' expands to every population in
        input_path. The argument is copied and never modified.
    time_range ([start:stop]): Time range shown; if None shows all (default: None)
    time_variable: Name of variable containing sample times (default: 't')
    state_variable: Name of state variable (default: 'v')
    max_units (int): maximum number of units from each population that will be plotted (default: 1)
    gid_set: gids to plot; if None, gids recorded for state_variable are selected
    n_trials: number of trials passed to statedata.read_state
    labels = ('legend', 'overlay'): only 'legend' currently draws labels (default: 'legend')
    lowpass_plot (None, 'subplot' or other): also plot the mean low-pass filtered
        trace, on a second subplot ('subplot') or overlaid on the main axes
    reduce (bool): plot the mean over namespaces instead of one line per namespace
    distance (bool): plot a distance vs. time heat map, rows ordered farthest first
    """
    if reduce and distance:
        raise RuntimeError(
            "plot_intracellular_state: reduce and distance are mutually exclusive"
        )
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)
    (population_ranges, N) = read_population_ranges(input_path)
    population_names = read_population_names(input_path)
    _, state_info = statedata.query_state(
        input_path, population_names, namespace_ids=namespace_ids
    )

    # Replace 'eachPop' with the list of populations. Work on a copy so that
    # neither the caller's list nor the default argument is mutated.
    include = list(include)
    if "eachPop" in include:
        include.remove("eachPop")
        include.extend(population_names)

    if gid_set is None:
        for population in include:
            for namespace in namespace_ids:
                if (population in state_info) and (
                    namespace in state_info[population]
                ):
                    ns_state_info_dict = dict(state_info[population][namespace])
                    if state_variable in ns_state_info_dict:
                        gid_set = list(ns_state_info_dict[state_variable])[
                            :max_units
                        ]
                        break
                    else:
                        raise RuntimeError(
                            "unable to find recording for state variable %s population %s namespace %s"
                            % (state_variable, population, namespace)
                        )

    # pop_states_dict[pop_name][gid][namespace_id] -> per-cell state record
    pop_states_dict = defaultdict(lambda: defaultdict(dict))
    for namespace_id in namespace_ids:
        logger.info(f"Reading state values from namespace {namespace_id}...")
        data = statedata.read_state(
            input_path,
            include,
            namespace_id,
            time_variable=time_variable,
            state_variables=[state_variable],
            time_range=time_range,
            max_units=max_units,
            gid=gid_set,
            n_trials=n_trials,
        )
        states = data["states"]
        for pop_name, pop_states in states.items():
            for gid, cell_states in pop_states.items():
                pop_states_dict[pop_name][gid][namespace_id] = cell_states

    # pop_state_mat_dict[pop_name][gid] ->
    #   (time record, 2D state array with one row per namespace, labels, distances)
    pop_state_mat_dict = defaultdict(dict)
    for pop_name, pop_states in pop_states_dict.items():
        for gid, cell_state_dict in pop_states.items():
            nss = sorted(cell_state_dict.keys())
            cell_state_x = cell_state_dict[nss[0]][time_variable]
            # Each row is the mean over the stacked rows of one namespace
            # (presumably one row per trial -- confirm with statedata).
            cell_state_mat = np.asarray(
                [
                    np.mean(np.vstack(cell_state_dict[ns][state_variable]), axis=0)
                    for ns in nss
                ],
                dtype=np.float32,
            )
            cell_state_distances = [
                cell_state_dict[ns]["distance"] for ns in nss
            ]
            cell_state_labels = [f"{ns} {state_variable}" for ns in nss]
            pop_state_mat_dict[pop_name][gid] = (
                cell_state_x,
                cell_state_mat,
                cell_state_labels,
                cell_state_distances,
            )

    stplots = []
    ax_lowpass = None
    if lowpass_plot == "subplot":
        fig, (ax, ax_lowpass) = plt.subplots(
            nrows=2, figsize=fig_options.figSize, sharex="all", sharey="all"
        )
    else:
        fig, ax = plt.subplots(
            figsize=fig_options.figSize, sharex="all", sharey="all"
        )
        if lowpass_plot is not None:
            ax_lowpass = ax

    legend_labels = []
    for pop_name, pop_states in pop_state_mat_dict.items():
        for gid, (
            cell_state_x,
            state_mat,
            state_labels,
            state_distances,
        ) in pop_states.items():
            m, n = state_mat.shape
            st_x = cell_state_x[0].reshape((n,))
            if distance:
                logger.info(f"cell_state_distances = {state_distances}")
                # Order rows by distance from soma, farthest first.
                order = np.argsort(state_distances, kind="stable")[::-1]
                d = np.asarray(state_distances)[order]
                pcm = ax.pcolormesh(
                    st_x, d, state_mat[order, :], cmap=fig_options.colormap
                )
                fig.colorbar(pcm, ax=ax, shrink=0.9, aspect=20)
                stplots.append(pcm)
                legend_labels.append(f"{pop_name} {gid}")
            else:
                cell_states = list(state_mat)
                if reduce:
                    cell_state = np.mean(np.vstack(cell_states), axis=0)
                    (line,) = ax.plot(st_x, cell_state)
                    stplots.append(line)
                    # Keep legend labels aligned with the plotted lines.
                    legend_labels.append(f"{pop_name} {gid}")
                    logger.info(
                        f"plot_state: min/max/mean value is "
                        f"{np.min(cell_state):.02f} / {np.max(cell_state):.02f} / "
                        f"{np.mean(cell_state):.02f}"
                    )
                else:
                    for i, cell_state in enumerate(cell_states):
                        (line,) = ax.plot(st_x, cell_state)
                        stplots.append(line)
                        logger.info(
                            f"plot_state: min/max/mean value of state {i} is "
                            f"{np.min(cell_state):.02f} / {np.max(cell_state):.02f} "
                            f"/ {np.mean(cell_state):.02f}"
                        )
                        if state_distances[i] is not None:
                            legend_labels.append(
                                f"{pop_name} {gid} "
                                f"{state_labels[i]} ({state_distances[i]:.02f} um)"
                            )
                        else:
                            legend_labels.append(
                                f"{pop_name} {gid} " f"{state_labels[i]}"
                            )
            if lowpass_plot is not None and not distance:
                # Best effort: a trace that cannot be filtered is logged and
                # skipped rather than aborting the whole plot.
                try:
                    filtered_cell_states = [
                        get_low_pass_filtered_trace(cell_state, st_x)
                        for cell_state in cell_states
                    ]
                    mean_filtered_cell_state = np.mean(
                        filtered_cell_states, axis=0
                    )
                    ax_lowpass.plot(
                        st_x,
                        mean_filtered_cell_state,
                        label=f"{pop_name} {gid} (filtered)",
                        linewidth=fig_options.lw,
                        alpha=0.75,
                    )
                except Exception as e:
                    logger.warning(
                        f"plot_intracellular_state: low-pass filtering failed "
                        f"for {pop_name} {gid}: {e}"
                    )

    ax.set_xlabel("Time [ms]", fontsize=fig_options.fontSize)
    if distance:
        ax.set_ylabel("distance from soma [um]", fontsize=fig_options.fontSize)
    else:
        ax.set_ylabel(state_variable, fontsize=fig_options.fontSize)

    # Add legend
    if labels == "legend":
        lgd = plt.legend(
            stplots,
            legend_labels,
            fontsize=fig_options.fontSize,
            scatterpoints=1,
            markerscale=5.0,
            loc="upper right",
            bbox_to_anchor=(0.5, 1.0),
        )
        ## From https://stackoverflow.com/questions/30413789/matplotlib-automatic-legend-outside-plot
        ## draw the legend on the canvas to assign it real pixel coordinates:
        plt.gcf().canvas.draw()
        ## transformation from pixel coordinates to Figure coordinates:
        transfig = plt.gcf().transFigure.inverted()
        ## Get the legend extents in pixels and convert to Figure coordinates.
        ## Pull out the farthest extent in the x direction since that is the canvas direction we need to adjust:
        lgd_pos = lgd.get_window_extent()
        lgd_coord = transfig.transform(lgd_pos)
        lgd_xmax = lgd_coord[1, 0]
        ## Do the same for the Axes:
        ax_pos = plt.gca().get_window_extent()
        ax_coord = transfig.transform(ax_pos)
        ax_xmax = ax_coord[1, 0]
        ## Adjust the Figure canvas using tight_layout for
        ## Axes that must move over to allow room for the legend to fit within the canvas:
        shift = 1 - (lgd_xmax - ax_xmax)
        plt.gcf().tight_layout(rect=(0, 0, shift, 1))

    # save figure
    if fig_options.saveFig:
        if isinstance(fig_options.saveFig, str):
            filename = fig_options.saveFig
        else:
            filename = input_path + " " + f"state.{fig_options.figFormat}"
        plt.savefig(filename)
    # show fig
    if fig_options.showFig:
        show_figure()
    return fig
[docs]def plot_network_clamp(
input_path,
spike_namespace,
intracellular_namespace,
gid,
target_input_features_path=None,
target_input_features_namespace=None,
target_input_features_arena_id=None,
target_input_features_trajectory_id=None,
config_file=None,
config_prefix="",
include="eachPop",
include_artificial=True,
time_range=None,
time_variable="t",
intracellular_variable="v",
labels="overlay",
pop_rates=True,
all_spike_hist=False,
spike_hist_bin=5,
lowpass_plot_type="overlay",
n_trials=-1,
marker=".",
opt_seed=None,
**kwargs,
):
"""
Raster plot of target cell intracellular trace + spike raster of presynaptic inputs. Returns the figure handle.
input_path: file with spike data
spike_namespace: attribute namespace for spike events
intracellular_namespace: attribute namespace for intracellular trace
target_input_features_path: optional file with input features
target_input_features_namespaces: optional attribute namespace for input features
config_path: path to network configuration file; required when target_input_features_path is given
time_range ([start:stop]): Time range of spikes shown; if None shows all (default: None)
time_variable: Name of variable containing spike times (default: 't')
labels = ('legend', 'overlay'): Show population labels in a legend or overlayed on one side of raster (default: 'legend')
pop_rates = (True|False): Include population rates (default: False)
spike_hist_bin (int): Size of bin in ms to use for histogram (default: 5)
marker (char): Marker for each spike (default: '.')
"""
fig_options = copy.copy(default_fig_options)
fig_options.update(kwargs)
(population_ranges, N) = read_population_ranges(input_path)
population_names = read_population_names(input_path)
_, state_info = statedata.query_state(
input_path, population_names, namespace_ids=[intracellular_namespace]
)
state_pop_name = None
pop_num_cells = {}
pop_start_inds = {}
for population in population_names:
pop_start_inds[population] = population_ranges[population][0]
pop_range = population_ranges[population]
pop_num_cells[population] = pop_range[1]
if gid is None:
for population in state_info.keys():
if intracellular_namespace in state_info[population]:
state_pop_name = population
gid = dict(state_info[population][intracellular_namespace])[
intracellular_variable
][0]
break
else:
for population in population_names:
pop_range = population_ranges[population]
if (gid >= pop_range[0]) and (gid < pop_range[0] + pop_range[1]):
state_pop_name = population
break
# Replace 'eachPop' with list of populations
if "eachPop" in include:
include.remove("eachPop")
for pop in population_names:
include.append(pop)
spk_include = include
if (state_pop_name is not None) and (state_pop_name not in spk_include):
spk_include.append(state_pop_name)
# sort according to start index
include.sort(key=lambda x: pop_start_inds[x])
include.reverse()
sys.stdout.flush()
spkdata = spikedata.read_spike_events(
input_path,
spk_include,
spike_namespace,
spike_train_attr_name=time_variable,
time_range=time_range,
n_trials=n_trials,
include_artificial=include_artificial,
)
logger.info(
"plot_network_clamp: reading recorded intracellular variable %s for gid %d"
% (intracellular_variable, gid)
)
indata = statedata.read_state(
input_path,
[state_pop_name],
intracellular_namespace,
time_variable=time_variable,
state_variables=[intracellular_variable],
time_range=time_range,
gid=[gid],
n_trials=n_trials,
)
spkpoplst = spkdata["spkpoplst"]
spkindlst = spkdata["spkindlst"]
spktlst = spkdata["spktlst"]
num_cell_spks = spkdata["num_cell_spks"]
pop_active_cells = spkdata["pop_active_cells"]
tmin = spkdata["tmin"]
tmax = spkdata["tmax"]
n_trials = spkdata["n_trials"]
if time_range is None:
time_range = [tmin, tmax]
if (
time_range[0] == time_range[1]
or time_range[0] == float("inf")
or time_range[1] == float("inf")
):
raise RuntimeError(
f"plot_network_clamp: invalid time_range: {time_range}"
)
time_bins = np.arange(time_range[0], time_range[1], spike_hist_bin)
baks_config = copy.copy(kwargs)
target_rate = None
target_rate_time = None
target_rate_ip = None
if (target_input_features_path is not None) and (
target_input_features_namespace is not None
):
if config_file is None:
raise RuntimeError(
"plot_network_clamp: config_file must be provided with target_input_features_path."
)
env = Env(
config_file=config_file,
arena_id=target_input_features_arena_id,
trajectory_id=target_input_features_trajectory_id,
config_prefix=config_prefix,
)
if env.analysis_config is not None:
baks_config.update(env.analysis_config["Firing Rate Inference"])
target_trj_rate_maps = stimulus.rate_maps_from_features(
env,
state_pop_name,
cell_index_set=[gid],
input_features_path=target_input_features_path,
input_features_namespace=target_input_features_namespace,
time_range=time_range,
include_time=True,
)
target_rate_time, target_rate = target_trj_rate_maps[gid]
target_rate_ip = interpolate.Akima1DInterpolator(
target_rate_time, target_rate
)
maxN = 0
minN = N
avg_rates = {}
tsecs = (time_range[1] - time_range[0]) / 1e3
for i, pop_name in enumerate(spkpoplst):
pop_num = len(pop_active_cells[pop_name])
maxN = max(maxN, max(pop_active_cells[pop_name]))
minN = min(minN, min(pop_active_cells[pop_name]))
if pop_num > 0:
if num_cell_spks[pop_name] == 0:
avg_rates[pop_name] = 0
else:
avg_rates[pop_name] = (
num_cell_spks[pop_name] / pop_num / n_trials
) / tsecs
pop_colors = {
pop_name: dflt_colors[ipop % len(dflt_colors)]
for ipop, pop_name in enumerate(spkpoplst)
}
pop_spk_dict = {
pop_name: (pop_spkinds, pop_spkts)
for (pop_name, pop_spkinds, pop_spkts) in zip(
spkpoplst, spkindlst, spktlst
)
}
N = pop_num_cells[pop_name]
S = pop_start_inds[pop_name]
n_plots = len(spkpoplst) + 2
plot_height_ratios = [1] * len(spkpoplst)
if all_spike_hist:
n_plots += 1
plot_height_ratios.append(1)
# Target spike plot
plot_height_ratios.append(1)
if target_rate_ip is not None:
n_plots += 2
plot_height_ratios.append(0.5)
plot_height_ratios.append(0.5)
# State plot
plot_height_ratios.append(2)
if lowpass_plot_type == "subplot":
n_plots += 1
plot_height_ratios.append(1)
fig, axes = plt.subplots(
nrows=n_plots,
sharex=True,
figsize=fig_options.figSize,
gridspec_kw={"height_ratios": plot_height_ratios},
)
stplots = []
def sphist(trial_spkts):
if len(trial_spkts) > 0:
bin_edges = np.histogram_bin_edges(
trial_spkts[0],
bins=np.arange(time_range[0], time_range[1], spike_hist_bin),
)
trial_sphist_ys = np.array(
[
np.histogram(spkts, bins=bin_edges)[0]
for spkts in trial_spkts
]
)
sphist_y = np.mean(trial_sphist_ys, axis=0)
sphist_x = bin_edges[:-1] + (spike_hist_bin / 2)
pch = interpolate.pchip(sphist_x, sphist_y)
res_npts = int(sphist_x.max() - sphist_x.min())
sphist_x_res = np.linspace(
sphist_x.min(), sphist_x.max(), res_npts, endpoint=True
)
sphist_y_res = pch(sphist_x_res)
else:
bin_edges = np.arange(time_range[0], time_range[1], spike_hist_bin)
sphist_x_res = bin_edges[:-1] + (spike_hist_bin / 2)
sphist_y_res = np.zeros(sphist_x_res.shape)
return sphist_x_res, sphist_y_res
for i, pop_name in enumerate(include):
pop_spkinds, pop_spkts = pop_spk_dict.get(pop_name, ([], []))
sphist_x, sphist_y = sphist(pop_spkts)
sph = axes[i].fill_between(
sphist_x,
sphist_y,
linewidth=fig_options.lw,
color=pop_colors.get(pop_name, dflt_colors[0]),
alpha=0.5,
label=pop_name,
)
axes[i].set_ylim(0.0, np.ceil(np.max(sphist_y)))
stplots.append(sph)
if i == 0:
axes[i].set_xlim(time_range)
axes[i].set_xlabel("Time (ms)", fontsize=fig_options.fontSize)
axes[i].set_ylabel("Spike Count", fontsize=fig_options.fontSize)
fig.subplots_adjust(hspace=0)
plt.setp([a.get_xticklabels() for a in fig.axes[:-2]], visible=False)
# set raster plot ticks to the end of the index range for each population
for i, pop_name in enumerate(include):
a = fig.axes[i]
start, end = a.get_ylim()
a.get_yaxis().set_ticks([end])
# set raster plot ticks to start and end of index range for first population
a = fig.axes[len(spkpoplst) - 1]
start, end = a.get_ylim()
a.get_yaxis().set_ticks([start, end])
if pop_rates:
lgd_labels = [
pop_name
+ " (%i active; %.3g Hz)"
% (len(pop_active_cells[pop_name]), avg_rates[pop_name])
for pop_name in spkpoplst
if pop_name in avg_rates
]
else:
lgd_labels = [
pop_name + f" ({len(pop_active_cells[pop_name])} active)"
for pop_name in spkpoplst
if pop_name in avg_rates
]
i_ax = len(spkpoplst)
if spktlst:
if all_spike_hist:
# Calculate and plot total spike histogram
all_trial_spkts = [list() for i in range(len(spktlst[0]))]
for i, pop_name in enumerate(include):
pop_spkinds, pop_spkts = pop_spk_dict.get(pop_name, ([], []))
for trial_i, this_trial_spkts in enumerate(pop_spkts):
all_trial_spkts[trial_i].append(this_trial_spkts)
merged_trial_spkts = [
np.concatenate(trial_spkts, axis=0)
for trial_spkts in all_trial_spkts
]
sphist_x, sphist_y = sphist(merged_trial_spkts)
sprate = np.sum(
avg_rates[pop_name] for pop_name in avg_rates
) / len(avg_rates)
ax_spk = axes[i_ax]
ax_spk.plot(sphist_x, sphist_y, linewidth=1.0)
ax_spk.set_xlabel("Time (ms)", fontsize=fig_options.fontSize)
ax_spk.set_xlim(time_range)
ax_spk.set_ylim((np.min(sphist_y), np.max(sphist_y) * 2))
if pop_rates:
lgd_label = f"mean firing rate: {sprate:.3g} Hz"
at = AnchoredText(
lgd_label,
loc="upper right",
borderpad=0.01,
prop=dict(size=fig_options.fontSizej),
)
ax_spk.add_artist(at)
i_ax += 1
# Calculate and plot spike histogram for target gid
pop_spkinds, pop_spkts = pop_spk_dict.get(state_pop_name, ([], []))
trial_sdf_ips = []
spk_count = 0
ax_spk = axes[i_ax]
for this_trial_spkinds, this_trial_spkts in zip_longest(
pop_spkinds, pop_spkts
):
spk_inds = np.argwhere(this_trial_spkinds == gid)
spk_count += len(spk_inds)
if target_rate_ip is not None:
sdf_dict = spikedata.spike_density_estimate(
state_pop_name,
{gid: this_trial_spkts[spk_inds]},
time_bins,
**baks_config,
)
trial_sdf_rate = sdf_dict[gid]["rate"]
trial_sdf_time = sdf_dict[gid]["time"]
trial_sdf_ip = interpolate.Akima1DInterpolator(
trial_sdf_time, trial_sdf_rate
)
trial_sdf_ips.append(trial_sdf_ip)
if len(spk_inds) > 0:
ax_spk.stem(
this_trial_spkts[spk_inds],
[0.5] * len(spk_inds),
markerfmt=" ",
use_line_collection=True,
)
ax_spk.set_yticks([])
sprate = spk_count / n_trials / tsecs
ax_spk.set_xlabel("Time (ms)", fontsize=fig_options.fontSize)
ax_spk.set_xlim(time_range)
if pop_rates:
lgd_label = "%s gid %d: %.3g Hz" % (state_pop_name, gid, sprate)
at = AnchoredText(
lgd_label,
loc="upper right",
borderpad=0.01,
prop=dict(size=fig_options.fontSize),
)
ax_spk.add_artist(at)
i_ax += 1
if target_rate is not None:
t = np.arange(time_range[0], time_range[1], 1.0)
target_rate_t_range = target_rate_ip(t)
if np.any(np.isnan(target_rate_t_range)):
target_rate_t_range[np.isnan(target_rate_t_range)] = 0.0
vmin, vmax = 0, np.max(target_rate_t_range)
ax_target_rate = axes[i_ax]
i_ax += 1
ax_target_rate.imshow(
target_rate_t_range[np.newaxis, :],
vmin=vmin,
vmax=vmax,
aspect="auto",
)
ax_target_rate.set_yticks([])
ax_mean_sdf = axes[i_ax]
i_ax += 1
if len(trial_sdf_ips) > 0:
trial_sdf_matrix = np.row_stack(
[trial_sdf_ip(t) for trial_sdf_ip in trial_sdf_ips]
)
mean_sdf = np.mean(trial_sdf_matrix, axis=0)
ax_mean_sdf.imshow(
mean_sdf[np.newaxis, :], vmin=vmin, vmax=vmax, aspect="auto"
)
ax_mean_sdf.set_yticks([])
# Plot intracellular state
ax_state = axes[i_ax]
ax_state.set_xlabel("Time (ms)", fontsize=fig_options.fontSize)
ax_state.set_ylabel(intracellular_variable, fontsize=fig_options.fontSize)
ax_state.set_xlim(time_range)
i_ax += 1
# Plot lowpass-filtered intracellular state if lowpass_plot_type is set to subplot
if lowpass_plot_type == "subplot":
ax_lowpass = axes[i_ax]
i_ax += 1
else:
ax_lowpass = ax_state
states = indata["states"]
stvplots = []
for pop_name, pop_states in states.items():
for gid, cell_states in pop_states.items():
st_len = cell_states[intracellular_variable][0].shape[0]
st_xs = [x[:st_len] for x in cell_states[time_variable]]
st_ys = [y[:st_len] for y in cell_states[intracellular_variable]]
st_x = st_xs[0]
try:
filtered_st_ys = [
get_low_pass_filtered_trace(st_y, st_x)
for st_x, st_y in zip(st_xs, st_ys)
]
filtered_st_y = np.mean(filtered_st_ys, axis=0)
ax_lowpass.plot(
st_x,
filtered_st_y,
label=f"{pop_name} (filtered)",
linewidth=fig_options.lw,
alpha=0.75,
)
except:
pass
for st_y in st_ys:
stvplots.append(
ax_state.plot(
st_x,
st_y,
label=pop_name,
linewidth=fig_options.lw,
alpha=0.5,
)
)
if labels == "legend":
# Shrink axes by 15%
for ax in axes:
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.85, box.height])
# Add legend
lgd = fig.legend(
stplots,
lgd_labels,
loc="center right",
fontsize="small",
scatterpoints=1,
markerscale=5.0,
bbox_to_anchor=(1.002, 0.5),
bbox_transform=plt.gcf().transFigure,
)
fig.artists.append(lgd)
elif labels == "overlay":
for i, (pop_name, lgd_label) in enumerate(zip(spkpoplst, lgd_labels)):
at = AnchoredText(
lgd_label,
loc="upper right",
borderpad=0.01,
prop=dict(size=fig_options.fontSize),
)
axes[i].add_artist(at)
max_label_len = max(len(l) for l in lgd_labels)
else:
raise RuntimeError(f"plot_network_clamp: unknown label type {labels}")
# save figure
ts = time.strftime("%Y%m%d_%H%M%S")
if fig_options.saveFig:
if isinstance(fig_options.saveFig, str):
filename = fig_options.saveFig
else:
filename = (
"Network Clamp %s %i.%s"
% (state_pop_name, gid, fig_options.figFormat)
if opt_seed is None
else "NetworkClamp_{!s}_{:d}_{!s}_{:08d}.{!s}".format(
state_pop_name, gid, ts, opt_seed, fig_options.figFormat
)
)
plt.savefig(filename)
# show fig
if fig_options.showFig:
show_figure()
return fig
def plot_single_vertex_dist(
    env,
    connectivity_path,
    coords_path,
    distances_namespace,
    target_gid,
    destination,
    source,
    extent_type="local",
    direction="in",
    bin_size=20.0,
    normed=False,
    comm=None,
    **kwargs,
):
    """
    Plot vertex distribution with respect to arc distance for a single cell.

    The soma U/V distances of ``source`` and ``destination`` are broadcast to
    all ranks. The ``source -> destination`` projection is scanned for edges
    involving ``target_gid``, and the partner cells' (U, V) positions are
    accumulated into 2-D bins with ``update_bins``. The per-rank bins are
    reduced onto rank 0, which draws them as a pseudocolor mesh with the
    target cell's own position marked by a red cross.

    This function is collective over ``comm``: every rank must call it.

    :param env: Env instance; ``env.geometry`` supplies the global extents.
    :param connectivity_path: NeuroH5 file containing the projection.
    :param coords_path: NeuroH5 file containing the soma distance attributes.
    :param distances_namespace: attribute namespace holding the
        "U Distance" / "V Distance" attributes.
    :param target_gid: gid of the cell whose connections are plotted.
    :param destination: name of the postsynaptic population.
    :param source: name of the presynaptic population.
    :param extent_type: "local" (axes fit the histogram bin edges) or
        "global" (axes span the extents measured for ``env.geometry``).
    :param direction: "in" (sources projecting onto ``target_gid``, which
        must belong to ``destination``) or "out" (destinations receiving
        edges from ``target_gid``, which must belong to ``source``).
    :param bin_size: bin size passed to ``update_bins`` / ``finalize_bins``.
    :param normed: if True, divide the counts by their maximum.
    :param comm: MPI communicator (default: MPI.COMM_WORLD).
    :param kwargs: figure options overriding ``default_fig_options``.
    :raises RuntimeError: for an unknown ``direction`` or ``extent_type``.
    """
    from miv_simulator.geometry.geometry import measure_distance_extents

    # Validate on every rank before any collective call. An error raised
    # only on rank 0 (after the reduction) would leave the other ranks
    # blocked at the final barrier.
    if direction not in ("in", "out"):
        raise RuntimeError(f"Unknown direction type {direction}")
    if extent_type not in ("local", "global"):
        raise RuntimeError(f"Unknown extent type {extent_type}")

    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    if comm is None:
        comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Soma distances of both populations, broadcast from rank 0.
    source_soma_distances = bcast_cell_attributes(
        coords_path, source, namespace=distances_namespace, comm=comm, root=0
    )
    destination_soma_distances = bcast_cell_attributes(
        coords_path,
        destination,
        namespace=distances_namespace,
        comm=comm,
        root=0,
    )
    (
        (total_x_min, total_x_max),
        (total_y_min, total_y_max),
    ) = measure_distance_extents(env.geometry, volume=network_volume)

    source_soma_distance_U = {}
    source_soma_distance_V = {}
    for gid, attrs in source_soma_distances:
        source_soma_distance_U[gid] = attrs["U Distance"][0]
        source_soma_distance_V[gid] = attrs["V Distance"][0]
    destination_soma_distance_U = {}
    destination_soma_distance_V = {}
    for gid, attrs in destination_soma_distances:
        destination_soma_distance_U[gid] = attrs["U Distance"][0]
        destination_soma_distance_V[gid] = attrs["V Distance"][0]
    del source_soma_distances
    del destination_soma_distances

    g = NeuroH5ProjectionGen(
        connectivity_path, source, destination, comm=comm, cache_size=20
    )
    # Each rank bins the edges it receives; the bins are reduced on rank 0.
    # The generator is always consumed to the end. It was created with
    # ``comm`` and may perform collective reads, so a rank that stops early
    # could stall the others (NOTE(review): confirm against neuroh5).
    dist_bins = {}
    if direction == "in":
        # Positions of the sources that project onto target_gid.
        for destination_gid, rest in g:
            if destination_gid != target_gid or rest is None:
                continue
            source_indexes, _ = rest
            for source_gid in source_indexes:
                update_bins(
                    dist_bins,
                    bin_size,
                    source_soma_distance_U[source_gid],
                    source_soma_distance_V[source_gid],
                )
    else:
        # Positions of the destinations that receive edges from target_gid.
        for destination_gid, rest in g:
            if rest is None:
                continue
            source_indexes, _ = rest
            for source_gid in source_indexes:
                if source_gid == target_gid:
                    update_bins(
                        dist_bins,
                        bin_size,
                        destination_soma_distance_U[destination_gid],
                        destination_soma_distance_V[destination_gid],
                    )

    add_bins_op = MPI.Op.Create(add_bins, commute=True)
    try:
        dist_bins = comm.reduce(dist_bins, op=add_bins_op)
    finally:
        # Release the user-defined reduction operator.
        add_bins_op.Free()

    if rank == 0:
        dist_hist_vals, dist_u_bin_edges, dist_v_bin_edges = finalize_bins(
            dist_bins, bin_size
        )
        if extent_type == "local":
            # Fit the view to the histogram bin edges.
            x_min, x_max = dist_u_bin_edges[0], dist_u_bin_edges[-1]
            y_min, y_max = dist_v_bin_edges[0], dist_v_bin_edges[-1]
        else:
            # Span the extents measured for env.geometry.
            x_min, x_max = total_x_min, total_x_max
            y_min, y_max = total_y_min, total_y_max

        X, Y = np.meshgrid(dist_u_bin_edges, dist_v_bin_edges)
        fig = plt.figure(figsize=fig_options.figSize)
        ax = plt.gca()
        ax.axis([x_min, x_max, y_min, y_max])

        # Mark the soma position of the target cell itself.
        if direction == "in":
            target_u = destination_soma_distance_U[target_gid]
            target_v = destination_soma_distance_V[target_gid]
        else:
            target_u = source_soma_distance_U[target_gid]
            target_v = source_soma_distance_V[target_gid]
        ax.plot(target_u, target_v, "r+", markersize=12, mew=3)

        H = np.array(dist_hist_vals.todense())
        if normed:
            H = np.divide(H.astype(float), float(np.max(H)))
        # Discrete colormap with one entry per 0.1 of the count range.
        pcm_boundaries = np.arange(0, np.max(H), 0.1)
        # pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        cmap_pls = plt.get_cmap(fig_options.colormap, len(pcm_boundaries))
        pcm_colors = list(cmap_pls(np.arange(len(pcm_boundaries))))
        pcm_cmap = mpl.colors.ListedColormap(pcm_colors[:-1], "")
        # Values below the colormap range are drawn transparent.
        pcm_cmap.set_under(pcm_colors[0], alpha=0.0)
        pcm = ax.pcolormesh(X, Y, H.T, cmap=pcm_cmap)
        clb_label = (
            "Normalized number of connections"
            if normed
            else "Number of connections"
        )
        clb = fig.colorbar(pcm, ax=ax, shrink=0.5, label=clb_label)
        clb.ax.tick_params(labelsize=fig_options.fontSize)
        ax.set_aspect("equal")
        ax.set_facecolor(pcm_colors[0])
        ax.tick_params(labelsize=fig_options.fontSize)
        ax.set_xlabel(
            "Longitudinal position (um)", fontsize=fig_options.fontSize
        )
        ax.set_ylabel("Transverse position (um)", fontsize=fig_options.fontSize)
        ax.set_title(
            f"Connectivity distribution ({direction}) of "
            f"{source} to {destination} for gid {target_gid}",
            fontsize=fig_options.fontSize,
        )

        # Save before showing. After a blocking show() the figure may already
        # be closed, and saving it then would write an empty image.
        if fig_options.saveFig:
            if isinstance(fig_options.saveFig, str):
                filename = fig_options.saveFig
            else:
                filename = f"Connection distance {direction} {source} to {destination} gid {target_gid}.{fig_options.figFormat}"
            fig.savefig(filename)
        if fig_options.showFig:
            show_figure()

    comm.barrier()
def update_spatial_rasters(
    frame,
    scts,
    timebins,
    n_trials,
    data,
    distances_U_dict,
    distances_V_dict,
    lgd,
):
    """
    FuncAnimation frame callback for the spatial spike raster.

    With ``N = len(timebins)``, frame ``frame`` selects trial ``frame // N``
    and the closed time window
    ``[timebins[frame % N], timebins[(frame + 1) % N]]``.

    For every population in ``data``, the spikes of that trial falling inside
    the window are placed at their cells' (U, V) soma distances on the
    matching line artist. The last artist (the time label) is then updated
    with the window's upper edge. Frame 0, or any negative frame, leaves the
    artists untouched.

    On the last bin of a trial the upper edge wraps around to
    ``timebins[0]``, so for increasing ``timebins`` that window selects no
    spikes.

    :param frame: animation frame index.
    :param scts: artists as returned by ``init_spatial_rasters``. Element
        ``p`` is the line for the ``p``-th entry of ``data``; the last
        element is the time label text.
    :param timebins: time bin edges shared by all trials.
    :param n_trials: number of trials; the trial number is shown only when
        there is more than one.
    :param data: sequence of ``(pop_name, spkinds, spkts)``, where
        ``spkinds[trial]`` / ``spkts[trial]`` are the spiking cell indices
        and spike times of that trial.
    :param distances_U_dict: population -> {cell index: U distance}.
    :param distances_V_dict: population -> {cell index: V distance}.
    :param lgd: unused; accepted so the ``fargs`` tuple matches.
    :return: ``scts`` (FuncAnimation blits these artists).
    """
    if frame <= 0:
        return scts

    n_bins = len(timebins)
    trial, bin_idx = divmod(frame, n_bins)
    t_start = timebins[bin_idx]
    t_end = timebins[(bin_idx + 1) % n_bins]

    for idx, (pop_name, spkinds, spkts) in enumerate(data):
        u_of = distances_U_dict[pop_name]
        v_of = distances_V_dict[pop_name]
        trial_times = spkts[trial]
        in_window = np.logical_and(trial_times >= t_start, trial_times <= t_end)
        active_cells = spkinds[trial][np.nonzero(in_window)]
        xs = np.asarray([u_of[cell] for cell in active_cells])
        ys = np.asarray([v_of[cell] for cell in active_cells])
        line = scts[idx]
        line.set_data(xs, ys)
        line.set_label(pop_name)

    if n_trials > 1:
        label = f"trial {trial}; t = {t_end:.02f} ms"
    else:
        label = f"t = {t_end:.02f} ms"
    scts[-1].set_text(label)
    return scts
def init_spatial_rasters(
    ax,
    timebins,
    n_trials,
    data,
    range_U_dict,
    range_V_dict,
    distances_U_dict,
    distances_V_dict,
    lgd,
    marker,
    pop_colors,
    **kwargs,
):
    """
    Create the initial artists for the spatial spike raster animation.

    One empty, animated line artist is drawn on ``ax`` per population in
    ``data``. ``update_spatial_rasters`` fills in the spike positions frame
    by frame. The axes limits are set to the union of the populations' U/V
    distance ranges, and a legend plus a time label are added.

    :param ax: axes to draw on.
    :param timebins: animation time bin edges; only ``timebins[0]`` is used
        here, for the initial time label.
    :param n_trials: unused here; kept for signature compatibility.
    :param data: sequence of ``(pop_name, spkinds, spkts)`` tuples; only the
        population names are used here.
    :param range_U_dict: population -> (min, max) U distance.
    :param range_V_dict: population -> (min, max) V distance.
    :param distances_U_dict: unused here; kept for signature compatibility.
    :param distances_V_dict: unused here; kept for signature compatibility.
    :param lgd: callable that takes the list of line artists and returns a
        legend artist.
    :param marker: matplotlib format string used for the spike markers.
    :param pop_colors: mapping population -> color. Populations without an
        entry fall back to the axes' color cycle.
    :param kwargs: figure options overriding ``default_fig_options``
        (``fontSize`` is used for the time label).
    :return: the line artists followed by the legend and the time label.
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    pop_names = [pop_name for pop_name, _, _ in data]

    scts = []
    for pop_name in pop_names:
        line_kwargs = {"animated": True, "alpha": 0.5}
        if pop_colors and pop_name in pop_colors:
            line_kwargs["color"] = pop_colors[pop_name]
        # Draw on the axes we were given, not on pyplot's current axes.
        scts.extend(ax.plot([], [], marker, **line_kwargs))

    # Fit the view to the union of all populations' distance ranges; the
    # limits are left untouched when there are no populations.
    if pop_names:
        ax.set_xlim(
            (
                min(range_U_dict[p][0] for p in pop_names),
                max(range_U_dict[p][1] for p in pop_names),
            )
        )
        ax.set_ylim(
            (
                min(range_V_dict[p][0] for p in pop_names),
                max(range_V_dict[p][1] for p in pop_names),
            )
        )

    legend = lgd(scts)
    time_label = ax.text(
        0.05,
        0.95,
        "t = %f ms" % timebins[0],
        fontsize=fig_options.fontSize,
        transform=ax.transAxes,
    )
    return scts + [legend, time_label]
# Module-level registry of the FuncAnimation objects created by
# plot_spatial_spike_raster. Presumably kept so each animation stays
# referenced (and keeps running) after the plotting function returns.
spatial_raster_aniplots = list()
## Plot spatial spike raster (animated)
def _read_spatial_raster_distances(coords_path, populations, distances_namespace):
    """Read the soma U/V distances of each population in *populations*.

    Returns ``(distance_U_dict, distance_V_dict, range_U_dict, range_V_dict)``.
    The first two map population -> {gid: distance}; the last two map
    population -> (min, max) of those distances.
    """
    distance_U_dict = {}
    distance_V_dict = {}
    range_U_dict = {}
    range_V_dict = {}
    for population in populations:
        distances = read_cell_attributes(
            coords_path, population, namespace=distances_namespace
        )
        distance_U = {}
        distance_V = {}
        for gid, attrs in distances:
            distance_U[gid] = attrs["U Distance"][0]
            distance_V[gid] = attrs["V Distance"][0]
        del distances
        logger.info("read distances (%i elements)", len(distance_U))
        U_values = np.asarray(list(distance_U.values()))
        V_values = np.asarray(list(distance_V.values()))
        range_U_dict[population] = (np.min(U_values), np.max(U_values))
        range_V_dict[population] = (np.min(V_values), np.max(V_values))
        distance_U_dict[population] = distance_U
        distance_V_dict[population] = distance_V
    return distance_U_dict, distance_V_dict, range_U_dict, range_V_dict


def plot_spatial_spike_raster(
    input_path,
    namespace_id,
    coords_path,
    distances_namespace="Arc Distances",
    include=("eachPop",),
    time_step=5.0,
    time_range=None,
    time_variable="t",
    include_artificial=True,
    max_spikes=int(1e6),
    marker="o",
    **kwargs,
):
    """
    Spatial raster animation of network spike times. Returns the figure handle.

    Spikes are grouped into consecutive time windows of about ``time_step``
    ms. Frame by frame, each spike is drawn at the (U, V) soma distance of
    the cell that fired it.

    input_path: file with spike data
    namespace_id: attribute namespace for spike events
    coords_path: file with the soma distance attributes
    distances_namespace: attribute namespace holding 'U Distance' and
        'V Distance' (default: 'Arc Distances')
    include (sequence of str): populations to plot; 'eachPop' expands to all
        populations in input_path (default: ('eachPop',)). The argument is
        not modified.
    time_step (float): width of each animation time window in ms (default: 5.0)
    time_range ([start:stop]): Time range of spikes shown; if None shows all (default: None)
    time_variable: Name of variable containing spike times (default: 't')
    include_artificial (bool): passed to spikedata.read_spike_events (default: True)
    max_spikes (int): currently unused (default: 1e6)
    marker (str): matplotlib format string for each spike (default: 'o')
    kwargs: figure options overriding default_fig_options. If saveFig is a
        string it is used as the movie filename; otherwise
        "<namespace_id> spatial raster.mp4" is written (requires ffmpeg).
    """
    fig_options = copy.copy(default_fig_options)
    fig_options.update(kwargs)

    population_names = read_population_names(input_path)

    # Work on a private copy: mutating the argument would alter the caller's
    # list, and (with a list default) every later call's default too.
    include = list(include)
    if "eachPop" in include:
        include.remove("eachPop")
        include.extend(population_names)

    (
        distance_U_dict,
        distance_V_dict,
        range_U_dict,
        range_V_dict,
    ) = _read_spatial_raster_distances(coords_path, include, distances_namespace)

    spkdata = spikedata.read_spike_events(
        input_path,
        include,
        namespace_id,
        spike_train_attr_name=time_variable,
        time_range=time_range,
        include_artificial=include_artificial,
    )
    n_trials = spkdata["n_trials"]
    spkpoplst = spkdata["spkpoplst"]
    spkindlst = spkdata["spkindlst"]
    spktlst = spkdata["spktlst"]
    tmin = spkdata["tmin"]
    tmax = spkdata["tmax"]

    pop_colors = {
        pop_name: dflt_colors[ipop % len(dflt_colors)]
        for ipop, pop_name in enumerate(spkpoplst)
    }

    # Plot spikes
    fig, ax = plt.subplots(figsize=fig_options.figSize)
    legend_labels = list(spkpoplst)

    def lgd(objs):
        return ax.legend(
            objs,
            legend_labels,
            fontsize=fig_options.fontSize,
            scatterpoints=1,
            markerscale=2.0,
            loc="upper right",
            bbox_to_anchor=(0.95, 0.95),
        )

    # k windows of width time_step need k + 1 edges; at least two edges are
    # required for a single window.
    n_edges = max(int((tmax - tmin) / time_step) + 1, 2)
    timebins = np.linspace(tmin, tmax, n_edges)
    data = list(zip(spkpoplst, spkindlst, spktlst))
    scts = init_spatial_rasters(
        ax,
        timebins,
        n_trials,
        data,
        range_U_dict,
        range_V_dict,
        distance_U_dict,
        distance_V_dict,
        lgd,
        marker,
        pop_colors,
        **kwargs,
    )
    ani = FuncAnimation(
        fig,
        func=update_spatial_rasters,
        frames=list(range(0, len(timebins) * n_trials - 1)),
        blit=True,
        repeat=False,
        init_func=lambda: scts,
        fargs=(
            scts,
            timebins,
            n_trials,
            data,
            distance_U_dict,
            distance_V_dict,
            lgd,
        ),
    )
    spatial_raster_aniplots.append(ani)

    # Save before showing, so the figure has not been closed by show().
    if fig_options.saveFig:
        if isinstance(fig_options.saveFig, str):
            filename = fig_options.saveFig
        else:
            filename = f"{namespace_id} spatial raster.mp4"
        Writer = writers["ffmpeg"]
        writer = Writer(fps=15, metadata=dict(artist="Me"), bitrate=1800)
        ani.save(filename, writer=writer)
    if fig_options.showFig:
        show_figure()
    return fig