{ "cells": [ { "cell_type": "markdown", "id": "512217dc", "metadata": {}, "source": [ "# CREMI example\n", "This example shows how to use volara to process a realistic dataset, in this case CREMI.\n", "The affinity predictions used here were also generated blockwise with volara. To see how,\n", "see the [volara-torch plugin example](https://e11bio.github.io/volara-torch/examples/cremi/cremi.html)." ] }, { "cell_type": "code", "execution_count": null, "id": "9129ddd0", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "import multiprocessing as mp\n", "\n", "# Start worker processes by forking, so they inherit this notebook's state\n", "mp.set_start_method(\"fork\", force=True)  # type: ignore[call-arg]" ] }, { "cell_type": "code", "execution_count": null, "id": "5aca1e10", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "from funlib.geometry import Coordinate\n", "\n", "# Output directory for the GIFs rendered below\n", "Path(\"_static/cremi\").mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "id": "c4ba20a7", "metadata": { "lines_to_next_cell": 0 }, "source": [ "First, let's visualize the data we've been given: the raw EM volume and the predicted affinities."
] }, { "cell_type": "code", "execution_count": null, "id": "97c7246f", "metadata": {}, "outputs": [], "source": [ "import matplotlib.animation as animation\n", "import matplotlib.pyplot as plt\n", "\n", "from volara.datasets import Affs, Raw\n", "\n", "# scale_shift rescales the raw intensities by 1/255 with no shift\n", "raw = Raw(store=\"sample_A+_20160601.zarr/raw\", scale_shift=(1 / 255, 0))  # type: ignore[arg-type]\n", "affs = Affs(store=\"sample_A+_20160601.zarr/affs\")  # type: ignore[arg-type]\n", "\n", "fig, axes = plt.subplots(1, 3, figsize=(14, 8))\n", "axes[0].set_title(\"Raw\")\n", "axes[1].set_title(\"Affs (short range)\")\n", "axes[2].set_title(\"Affs (long range)\")\n", "\n", "ims = []\n", "# Affinities are stored channels-first (c, z, y, x); swap the first two axes to iterate over z\n", "for i, (raw_slice, affs_slice) in enumerate(\n", "    zip(raw.array(\"r\")[:], affs.array(\"r\")[:].transpose([1, 0, 2, 3]))\n", "):\n", "    # Every frame after the first is drawn as an animated artist\n", "    animated = i > 0\n", "    # Raw EM section\n", "    im_raw = axes[0].imshow(raw_slice, cmap=\"gray\", animated=animated)\n", "    # Affinity channels 0-2 shown as RGB\n", "    im_affs_short = axes[1].imshow(\n", "        affs_slice[0:3].transpose([1, 2, 0]),\n", "        vmin=0,\n", "        vmax=255,\n", "        interpolation=\"none\",\n", "        animated=animated,\n", "    )\n", "    # Affinity channels 0, 5 and 6 shown as RGB\n", "    im_affs_long = axes[2].imshow(\n", "        affs_slice[[0, 5, 6]].transpose([1, 2, 0]),\n", "        vmin=0,\n", "        vmax=255,\n", "        interpolation=\"none\",\n", "        animated=animated,\n", "    )\n", "    ims.append([im_raw, im_affs_short, im_affs_long])\n", "\n", "# Play the frames forward, then backward, so the GIF loops smoothly\n", "ims = ims + ims[::-1]\n", "ani = animation.ArtistAnimation(fig, ims, blit=True)\n", "ani.save(\"_static/cremi/inputs.gif\", writer=\"pillow\", fps=10)\n", "plt.close()" ] }, { "cell_type": "markdown", "id": 
"bf69dea1", "metadata": {}, "source": [ "![inputs](_static/cremi/inputs.gif)" ] }, { "cell_type": "markdown", "id": "749e7831", "metadata": {}, "source": [ "Now we can convert the affinities into a segmentation. We run mutex watershed on the affinities in a multi-step process:\n", "\n", "1) Local fragment extraction - This step runs blockwise and generates fragments from the affinities. For each fragment we save a node in a graph with attributes such as its spatial position and size.\n", "2) Edge extraction - This step runs blockwise and computes the mean affinity between neighboring fragments, adding the results as edges to the fragment graph.\n", "3) Graph mutex watershed - This step runs on the fragment graph and creates a lookup table mapping fragment id -> segment id.\n", "4) Relabel fragments - This step runs blockwise and applies the lookup table to the fragments to create the final segmentation." ] }, { "cell_type": "code", "execution_count": null, "id": "ba731ae2", "metadata": {}, "outputs": [], "source": [ "from volara.blockwise import AffAgglom, ExtractFrags, GraphMWS, Relabel\n", "from volara.datasets import Labels\n", "from volara.dbs import SQLite\n", "from volara.lut import LUT" ] }, { "cell_type": "markdown", "id": "e3003244", "metadata": {}, "source": [ "First, let's define the graph and arrays we are going to use.\n", "\n", "Because our graph is stored in an SQLite database, we need to define a schema with column names and types\n", "for the node and edge attributes:\n", "\n", "- Nodes: default attributes such as \"id\", \"position\", and \"size\" are already defined, so we only need to\n", "  declare additional attributes. In this case there are none.\n", "- Edges: default attributes such as \"id\", \"u\", and \"v\" are already defined, so we only add\n", "  \"xy_aff\", \"z_aff\", and \"lr_aff\" to store the mean affinities between fragments."
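, "\n",
"\n",
"As a rough illustration of what the `xy_aff`, `z_aff`, and `lr_aff` edge attributes hold, here is a minimal\n",
"NumPy sketch (not volara's implementation) that averages the affinities over all voxel pairs connecting two\n",
"fragments for a single offset. It assumes that `affs[z, y, x]` is the affinity between voxel `(z, y, x)` and\n",
"its neighbor at `(z, y, x) + offset`, that the offset is non-negative, and that label `0` is background;\n",
"`mean_edge_affinities` is a hypothetical helper.\n",
"\n",
"```python\n",
"from collections import defaultdict\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def mean_edge_affinities(frags, affs, offset):\n",
"    # Hypothetical helper: mean affinity per pair of touching fragments for one offset\n",
"    src = tuple(slice(0, s - o) for s, o in zip(frags.shape, offset))\n",
"    dst = tuple(slice(o, s) for s, o in zip(frags.shape, offset))\n",
"    u = frags[src].ravel().tolist()  # fragment id of each voxel\n",
"    v = frags[dst].ravel().tolist()  # fragment id of its neighbor at voxel + offset\n",
"    a = affs[src].ravel().tolist()  # affinity assigned to that voxel pair\n",
"    sums = defaultdict(float)\n",
"    counts = defaultdict(int)\n",
"    for fu, fv, aff in zip(u, v, a):\n",
"        if fu == fv or fu == 0 or fv == 0:\n",
"            continue  # same fragment or background: not a fragment-graph edge\n",
"        key = (min(fu, fv), max(fu, fv))\n",
"        sums[key] += aff\n",
"        counts[key] += 1\n",
"    return {key: sums[key] / counts[key] for key in sums}\n",
"\n",
"\n",
"toy_frags = np.array([[[1, 1, 2], [1, 2, 2]]])\n",
"toy_affs = np.array([[[0.9, 0.2, 0.7], [0.4, 0.8, 0.6]]])\n",
"print(mean_edge_affinities(toy_frags, toy_affs, (0, 0, 1)))  # edge (1, 2): mean of 0.2 and 0.4\n",
"```\n",
"\n",
"Pooling the sums and counts over several offsets would give one score per group of offsets, similar in spirit\n",
"to the `scores` mapping passed to `AffAgglom` below."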
] }, { "cell_type": "code", "execution_count": null, "id": "6c164272", "metadata": {}, "outputs": [], "source": [ "fragments_graph = SQLite(\n", "    path=\"sample_A+_20160601.zarr/fragments.db\",  # type: ignore[arg-type]\n", "    edge_attrs={\"xy_aff\": \"float\", \"z_aff\": \"float\", \"lr_aff\": \"float\"},\n", ")\n", "fragments_dataset = Labels(store=\"sample_A+_20160601.zarr/fragments\")  # type: ignore[arg-type]\n", "segments_dataset = Labels(store=\"sample_A+_20160601.zarr/segments\")  # type: ignore[arg-type]" ] }, { "cell_type": "markdown", "id": "2224ac43", "metadata": {}, "source": [ "Now we define the four tasks with the parameters we want to use and chain them into a single pipeline." ] }, { "cell_type": "code", "execution_count": null, "id": "0c1ad631", "metadata": {}, "outputs": [], "source": [ "# Use the chunk shape of the raw data's underlying storage as the block size\n", "block_size = raw.array(\"r\")._source_data.chunks\n", "\n", "# Extract fragments blockwise by running mutex watershed on the voxel affinities\n", "extract_frags = ExtractFrags(\n", "    db=fragments_graph,\n", "    affs_data=affs,\n", "    frags_data=fragments_dataset,\n", "    block_size=block_size,\n", "    context=Coordinate(6, 12, 12),\n", "    bias=[-0.6] + [-0.4] * 2 + [-0.6] * 2 + [-0.8] * 2,  # one bias per affinity channel\n", "    strides=(\n", "        [Coordinate(1, 1, 1)] * 3\n", "        + [Coordinate(1, 3, 3)] * 2  # Longer-range affinities use larger strides\n", "        + [Coordinate(1, 6, 6)] * 2  # to avoid excessive splitting\n", "    ),\n", "    randomized_strides=True,  # treat strides as sampling probabilities of 1/prod(stride)\n", "    remove_debris=64,  # remove very small fragments\n", "    num_workers=4,\n", ")\n", "\n", "# Compute edge scores between fragments as the mean affinity across all edges connecting two fragments\n", "aff_agglom = AffAgglom(\n", "    db=fragments_graph,\n", "    affs_data=affs,\n", "    frags_data=fragments_dataset,\n", "    block_size=block_size,\n", "    context=Coordinate(3, 6, 6),\n", "    scores={\n", "        \"z_aff\": affs.neighborhood[0:1],\n", "        \"xy_aff\": affs.neighborhood[1:3],\n", "        \"lr_aff\": affs.neighborhood[3:],\n", "    },\n", "    num_workers=4,\n",
")\n", "\n", "# Run mutex watershed again, this time on the fragment graph with agglomerated edges\n", "# instead of the voxel graph of affinities. The result is a fragment -> segment lookup table.\n", "lut = LUT(path=\"sample_A+_20160601.zarr/lut.npz\")  # type: ignore[arg-type]\n", "total_roi = raw.array(\"r\").roi\n", "graph_mws = GraphMWS(\n", "    db=fragments_graph,\n", "    lut=lut,\n", "    weights={\"xy_aff\": (1, -0.4), \"z_aff\": (1, -0.6), \"lr_aff\": (1, -0.6)},\n", "    roi=total_roi,\n", ")\n", "\n", "# Relabel the fragments into segments by applying the lookup table blockwise\n", "relabel = Relabel(\n", "    lut=lut,\n", "    frags_data=fragments_dataset,\n", "    seg_data=segments_dataset,\n", "    block_size=block_size,\n", "    num_workers=4,\n", ")\n", "\n", "# Chain the four tasks into one pipeline and run it blockwise with worker processes\n", "pipeline = extract_frags + aff_agglom + graph_mws + relabel\n", "pipeline.run_blockwise(multiprocessing=True)" ] }, { "cell_type": "markdown", "id": "7290d4c9", "metadata": { "lines_to_next_cell": 2 }, "source": [ "Let's visualize the results.\n", "\n", "If you are following along on your own, I highly recommend installing `funlib.show.neuroglancer` and\n", "running its command line tool via `neuroglancer -d sample_A+_20160601.zarr/*` to view the results in\n", "neuroglancer.\n", "\n", "For the purposes of this page, we will render a simple GIF instead." ] }, { "cell_type": "code", "execution_count": null, "id": "5f6bf10c", "metadata": {}, "outputs": [], "source": [ "import matplotlib.animation as animation\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from matplotlib.colors import ListedColormap\n", "\n", "# Take every second pixel in y and x to keep the GIF small\n", "fragments = fragments_dataset.array(\"r\")[:, ::2, ::2]\n", "segments = segments_dataset.array(\"r\")[:, ::2, ::2]\n", "raw_data = raw.array(\"r\")[:, ::2, ::2]\n", "\n", "# Collect every label that occurs in either volume\n", "unique_labels = set(np.unique(fragments)) | set(np.unique(segments))\n", "num_labels = len(unique_labels)\n", "\n", "\n", "def random_color(label):\n", "    # Deterministic RGB color: seed a random generator with the label value\n", "    rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(label)))\n", "    return np.array((rs.random(), rs.random(), rs.random()))\n",
"\n", "\n", "# One random color per colormap index (labels are remapped to indices below)\n", "random_fragment_colors = [random_color(index) for index in range(num_labels)]\n", "\n", "# Create a colormap\n", "cmap_labels = ListedColormap(random_fragment_colors)\n", "\n", "# Map labels to consecutive indices for the colormap\n", "label_to_index = {label: i for i, label in enumerate(unique_labels)}\n", "indexed_fragments = np.vectorize(label_to_index.get)(fragments)\n", "indexed_segments = np.vectorize(label_to_index.get)(segments)\n", "\n", "fig, axes = plt.subplots(1, 3, figsize=(18, 8))\n", "axes[0].set_title(\"Raw\")\n", "axes[1].set_title(\"Fragments\")\n", "axes[2].set_title(\"Segments\")\n", "\n", "ims = []\n", "for i, (raw_slice, fragments_slice, segments_slice) in enumerate(\n", "    zip(raw_data, indexed_fragments, indexed_segments)\n", "):\n", "    # Every frame after the first is drawn as an animated artist\n", "    animated = i > 0\n", "    im_raw = axes[0].imshow(raw_slice, cmap=\"gray\", animated=animated)\n", "    # vmin/vmax span all indices, so each index maps to its own colormap entry\n", "    im_fragments = axes[1].imshow(\n", "        fragments_slice,\n", "        cmap=cmap_labels,\n", "        vmin=0,\n", "        vmax=num_labels,\n", "        interpolation=\"none\",\n", "        animated=animated,\n", "    )\n", "    im_segments = axes[2].imshow(\n", "        segments_slice,\n", "        cmap=cmap_labels,\n", "        vmin=0,\n", "        vmax=num_labels,\n", "        interpolation=\"none\",\n", "        animated=animated,\n", "    )\n", "    ims.append([im_raw, im_fragments, im_segments])\n", "\n", "# Play the frames forward, then backward, so the GIF loops smoothly\n", "ims = ims + ims[::-1]\n", "ani = animation.ArtistAnimation(fig, ims, blit=True)\n", "ani.save(\"_static/cremi/segmentation.gif\", writer=\"pillow\", fps=10)\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "19ac46cb", "metadata": {}, "source": [
"The final segmentation is shown below. It is clearly not a great segmentation, but it is\n", "reasonably good for a model small enough to process a CREMI dataset in 20 minutes on a GitHub\n", "Actions runner.\n", "\n", "![segmentation](_static/cremi/segmentation.gif)" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" } }, "nbformat": 4, "nbformat_minor": 5 }