Separation property of a reservoir culture

Note

This example uses the imperative interface.

This example constructs a microcircuit reservoir simulation that reproduces the classic separation property experiment of Maass et al. (2002).

To run this example, make sure to change into the 3-interface-api directory within the MiV-Simulator-Cases repository. Then run the entire rc-separation.py script or execute each individual step below.

  1# %% [markdown]
  2# ## Create the network
  3
  4# %%
  5from machinable import get
  6from miv_simulator.utils import from_yaml
  7
  8synapses_config = from_yaml("simulation/config/synapses.yml")
  9
 10h5_types = get(
 11    "miv_simulator.interface.h5_types",
 12    [
 13        {
 14            "cell_distributions": {
 15                "STIM": {"SO": 0, "SP": 64, "SR": 0, "SLM": 0},
 16                "PYR": {"SO": 0, "SP": 223, "SR": 0, "SLM": 0},
 17                "PVBC": {"SO": 35, "SP": 50, "SR": 8, "SLM": 0},
 18                "OLM": {"SO": 21, "SP": 0, "SR": 0, "SLM": 0},
 19            },
 20            "projections": {
 21                post: list(pre.keys()) for post, pre in synapses_config.items()
 22            },
 23        },
 24    ],
 25).launch()
 26
 27network = get(
 28    "miv_simulator.interface.network_architecture",
 29    {
 30        "filepath": h5_types.output_filepath,
 31        "cell_distributions": h5_types.config.cell_distributions,
 32        "layer_extents": {
 33            "SO": [[0.0, 0.0, 0.0], [200.0, 200.0, 5.0]],
 34            "SP": [[0.0, 0.0, 5.0], [200.0, 200.0, 50.0]],
 35            "SR": [[0.0, 0.0, 50.0], [200.0, 200.0, 100.0]],
 36            "SLM": [[0.0, 0.0, 100.0], [200.0, 200.0, 150.0]],
 37        },
 38    },
 39    uses=h5_types,
 40).launch()
 41
 42measure_distances = network.measure_distances().launch()
 43
 44synapse_forest = {
 45    population: network.generate_synapse_forest(
 46        {
 47            "population": population,
 48            "morphology": f"./simulation/morphology/{population}.swc",
 49        },
 50        uses=measure_distances,
 51    ).launch()
 52    for population in ["PYR", "PVBC", "OLM"]
 53}
 54
 55synapses = {
 56    population: network.distribute_synapses(
 57        {
 58            "forest_filepath": synapse_forest[population].output_filepath,
 59            "cell_types": "from_file('simulation/config/cell_types.yml')",
 60            "population": population,
 61            "distribution": "poisson",
 62            "mechanisms_path": "./simulation/mechanisms",
 63            "template_path": "./simulation/templates",
 64            "io_size": 1,
 65            "write_size": 0,
 66        },
 67        uses=list(synapse_forest.values()),
 68    ).launch()
 69    for population in ["PYR", "PVBC", "OLM"]
 70}
 71
 72connections = {
 73    population: network.generate_connections(
 74        {
 75            "synapses": synapses_config,
 76            "forest_filepath": synapses[population].output_filepath,
 77            "axon_extents": {
 78                "STIM": {"default": {"width": [200, 200], "offset": [0, 0]}},
 79                "PYR": {"default": {"width": [200, 200], "offset": [0, 0]}},
 80                "PVBC": {"default": {"width": [200, 200], "offset": [0, 0]}},
 81                "OLM": {"default": {"width": [200, 200], "offset": [0, 0]}},
 82            },
 83            "template_path": "./simulation/templates",
 84            "io_size": 1,
 85            "cache_size": 20,
 86            "write_size": 100,
 87        },
 88        uses=list(synapses.values()),
 89    ).launch()
 90    for population in ["PYR", "PVBC", "OLM"]
 91}
 92
 93graph = get(
 94    "miv_simulator.interface.neuroh5_graph",
 95    uses=[
 96        network,
 97        *synapse_forest.values(),
 98        *synapses.values(),
 99        *connections.values(),
100    ],
101).launch()
102
103graph.files()
104
105# %% [markdown]
106# ## Separation property experiment
107#
108# Following Maass et al. 2002 (Figure 2).
109
110# %%
111import numpy as np
112from collections import defaultdict
113from miv_simulator import coding
114from machinable import Component
115
116
def generate_poisson_spike_train(rate_hz, duration_ms) -> coding.SpikeTimes:
    """Draw a homogeneous Poisson spike train.

    Args:
        rate_hz: Firing rate in Hz. A rate of 0 yields an empty train.
        duration_ms: Length of the train in milliseconds.

    Returns:
        Spike times in milliseconds, all within ``[0, duration_ms)``,
        converted with ``coding.cast_spike_times``.

    Raises:
        ValueError: If ``rate_hz`` or ``duration_ms`` is negative.
    """
    if rate_hz < 0:
        raise ValueError(f"rate_hz must be non-negative, got {rate_hz}")
    if duration_ms < 0:
        raise ValueError(f"duration_ms must be non-negative, got {duration_ms}")
    duration_s = duration_ms / 1000.0
    if rate_hz == 0 or duration_s == 0:
        return coding.cast_spike_times(np.empty(0))

    scale = 1.0 / rate_hz
    # Draw inter-spike intervals in batches of ~1.5x the expected spike count
    # and keep extending until the train covers the whole duration. A single
    # fixed-size batch would silently truncate trains that happen to need
    # more spikes than the batch holds (and would be empty for tiny rates).
    batch = max(int(rate_hz * duration_s * 1.5), 1)
    spike_times = np.cumsum(np.random.exponential(scale, batch))
    while spike_times[-1] < duration_s:
        more = spike_times[-1] + np.cumsum(np.random.exponential(scale, batch))
        spike_times = np.concatenate([spike_times, more])
    spike_times = spike_times[spike_times < duration_s]
    return coding.cast_spike_times(spike_times * 1000)
123
124
def gaussian_conv(spike_times, duration, tau=5, dt=0.1):
    """Smooth a spike train with a Gaussian kernel (Maass et al. 2002).

    The train is sampled on the grid ``np.arange(0, duration + dt, dt)``.
    Every spike adds a unit impulse at its nearest grid point, and the result
    is convolved with ``exp(-(x / tau) ** 2)``, truncated to ``|x| <= 3 * tau``.

    Args:
        spike_times: Iterable of spike times, in the same unit as ``duration``.
        duration: Length of the observation window.
        tau: Width of the Gaussian kernel.
        dt: Sampling step.

    Returns:
        1-D array holding the smoothed signal. ``np.convolve(..., mode="same")``
        returns ``max(len(grid), len(kernel))`` samples.

    Raises:
        ValueError: If a spike falls outside the sampling grid.
    """
    t = np.arange(0, duration + dt, dt)
    spike_function = np.zeros_like(t)
    # Round to the nearest sample. int() truncation would shift spikes such as
    # 0.3 (0.3 / 0.1 == 2.9999999999999996) into the previous bin.
    idx = np.rint(np.asarray(spike_times, dtype=float) / dt).astype(int)
    if idx.size and (idx.min() < 0 or idx.max() >= t.size):
        # negative indices would otherwise silently wrap to the end
        raise ValueError(
            f"spike times must lie within [0, {duration}] (dt={dt})"
        )
    # accumulate so that coincident spikes in one bin each contribute a kernel
    np.add.at(spike_function, idx, 1)
    w = 3  # kernel half-width in units of tau
    x = np.arange(-w * tau, w * tau + dt, dt)
    kernel = np.exp(-((x / tau) ** 2))
    return np.convolve(spike_function, kernel, mode="same")
134
135
def spike_train_distance(u_times, v_times, duration, tau=5):
    """Return the spike-train distance d(u, v) of Maass et al. 2002.

    Both trains are smoothed with ``gaussian_conv`` (kernel width ``tau``,
    default sampling step). The Euclidean norm of their difference is then
    divided by ``duration``.
    """
    smoothed_u, smoothed_v = (
        gaussian_conv(times, duration, tau) for times in (u_times, v_times)
    )
    return np.linalg.norm(smoothed_u - smoothed_v) / duration
142
143
class SpikeTrainPairs(Component):
    """Generate pairs of Poisson spike trains at a given distance d(u, v).

    ``__call__`` first calibrates a firing rate. For each candidate rate
    ``ff * splits`` Hz (``ff = 1..99``) it averages the distance between two
    independent trains over 50 draws. It picks the rate whose mean distance is
    closest to ``config.distance``.

    It then draws train pairs at that rate and keeps those whose distance lies
    within ``eps = 0.01`` of the target, until ``N`` pairs are collected. The
    pairs are saved to ``data.p``.
    """

    class Config:
        N: int = 200  # number of accepted (u, v) pairs
        distance: float = 0.1  # target spike_train_distance(u, v)
        splits: int = 10  # step, in Hz, of the candidate firing-rate grid
        duration: float = 500  # spike train length in ms

    def __call__(self):
        target = self.config.distance
        eps = 0.01  # acceptance tolerance around the target distance

        def draw_train(rate_hz):
            return generate_poisson_spike_train(rate_hz, self.config.duration)

        def distance_between(a, b):
            return spike_train_distance(a, b, duration=self.config.duration)

        # Calibrate the firing rate. Pick the candidate whose mean distance is
        # closest to the target, so that the rejection sampling below accepts
        # pairs as often as possible. (5 Hz is spontaneous activity, so rates
        # from 10 Hz upwards are explored.)
        best_dd, best_rate = None, None
        for ff in range(1, 100, 1):
            rate = ff * self.config.splits
            dd = np.mean(
                [
                    distance_between(draw_train(rate), draw_train(rate))
                    for _ in range(50)
                ]
            )
            print(rate, dd)
            if best_dd is None or abs(target - dd) < abs(target - best_dd):
                best_dd, best_rate = dd, rate
        m = {target: (best_dd, best_rate)}
        print(f"Frequency-distance map: {m}")

        # Rejection sampling at the calibrated rate. There is no attempt
        # limit: if no calibrated mean is near the target, this loop can run
        # for a very long time.
        pairs = []
        while len(pairs) < self.config.N:
            u, v = draw_train(best_rate), draw_train(best_rate)
            if np.abs(distance_between(u, v) - target) < eps:
                pairs.append((u, v))

        self.save_file("data.p", pairs)

    def data(self):
        """Return the saved list of (u, v) pairs.

        Returns ``None``, the ``load_file`` default, when ``data.p`` cannot be
        loaded.
        """
        return self.load_file("data.p", None)
179            while len(data[target]) < self.config.N:
180                u, v = g(f), g(f)
181                dd = d(u, v)
182                if np.abs(dd - target) < eps:
183                    data[target].append((u, v))
184
185        self.save_file("data.p", data[self.config.distance])
186
187    def data(self):
188        return self.load_file("data.p", None)
189
190
191with get("interface.execution.slurm"):
192    spike_trains = {
193        distance: get(SpikeTrainPairs, {"distance": distance}).launch()
194        for distance in [0.1, 0.2, 0.4]
195    }
196
197# %% [markdown]
198# ## Reproduce Figure 2 from Maass et al. 2002
199#
200
201# %%
202
203from matplotlib import pyplot as plt
204
205fig = plt.figure()
206for distance in [
207    0.0,
208    0.1,
209]:  # 0.2, 0.4]:
210    total = 0
211    finished = 0
212    data = spike_trains[
213        distance if distance != 0 else list(spike_trains.keys())[0]
214    ].data()
215    x = []
216    y = []
217    for trial in range(5):  # spike_trains[distance].config.N
218        experiment = {}
219        for u_or_v in range(2):
220            stimulus = data[trial][u_or_v].tolist()
221            context = {}
222            if distance == 0.0:
223                # use the same stimulus u but with different initialization context
224                stimulus = data[trial][0].tolist()
225                context = {"state": u_or_v}
226            with get("machinable.scope", {"trial": trial, **context}):
227                experiment["u" if u_or_v == 0 else "v"] = e = get(
228                    "interface.experiment.rc",
229                    [
230                        graph.files(),
231                        {
232                            "t_end": 500,
233                            "cell_types": "from_file('simulation/config/cell_types.yml')",
234                            "synapses": "from_file('simulation/config/synapses.yml')",
235                            "stimulus": stimulus,
236                        },
237                    ],
238                ).launch()
239                total += 1
240                if e.cached():
241                    finished += 1
242        u = experiment["u"].readout()
243        v = experiment["v"].readout()
244        if u is None or v is None:
245            continue
246        state_distance = np.abs(u[:, 1] - v[:, 1])
247        x = u[:, 0]
248        y.append(state_distance)
249    print(
250        f"For distance {distance}, found {finished}/{total} cached experiments"
251    )
252    if finished != total:
253        continue
254    state_distances = np.array(y)
255    state_distance_avg = y = np.mean(state_distances, axis=0)
256    state_distance_std = error = np.std(state_distances, axis=0)
257    reduced = np.mean(state_distance_avg)
258    plt.plot(x, y, label=f"d(u,v)={distance} (mean={round(reduced, 4)})")
259    # plt.fill_between(x, y - error, y + error)
260
261plt.legend(loc="upper right")