{ "cells": [ { "cell_type": "markdown", "id": "a7c076ee", "metadata": {}, "source": [ "# Cremi example\n", "This example shows how to use volara to predict LSDs and affinities on the cremi dataset, and then run mutex watershed on the predicted affinities." ] }, { "cell_type": "code", "execution_count": null, "id": "c9530bbf", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import wget\n", "from funlib.geometry import Coordinate\n", "\n", "Path(\"_static/cremi\").mkdir(parents=True, exist_ok=True)\n", "\n", "# Download some cremi data\n", "# immediately convert it to zarr for convenience\n", "if not Path(\"sample_A+_20160601.zarr\").exists():\n", " wget.download(\n", " \"https://cremi.org/static/data/sample_A+_20160601.hdf\", \"sample_A+_20160601.hdf\"\n", " )\n", "if not Path(\"sample_A+_20160601.zarr/raw\").exists():\n", " import h5py\n", " import zarr\n", "\n", " raw_ds = zarr.open(\"sample_A+_20160601.zarr\", \"w\").create_dataset(\n", " \"raw\", data=h5py.File(\"sample_A+_20160601.hdf\", \"r\")[\"volumes/raw\"][:]\n", " )\n", " raw_ds.attrs[\"voxel_size\"] = (40, 4, 4)\n", " raw_ds.attrs[\"axis_names\"] = [\"z\", \"y\", \"x\"]\n", " raw_ds.attrs[\"unit\"] = [\"nm\", \"nm\", \"nm\"]" ] }, { "cell_type": "markdown", "id": "6d6762f9", "metadata": {}, "source": [ "Now we can predict the LSDs and affinities for this dataset. We have provided a very simple\n", "pretrained model for this dataset. We went for speed and efficiency over accuracy for this\n", "model so that it can run in a github action. You can train a significantly better model\n", "with access to a GPU and more Memmory." ] }, { "cell_type": "code", "execution_count": null, "id": "f4a85b40", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "# Here are some important details about the model:\n", "\n", "# The number of output channels of our model. 10 lsds, 7 affinities\n", "out_channels = [10, 7]\n", "\n", "# The input shape of our model (not including channels)\n", "min_input_shape = Coordinate(36, 252, 252)\n", "\n", "# The output shape of our model (not including channels)\n", "min_output_shape = Coordinate(32, 160, 160)\n", "\n", "# The minimum increment for adjusting the input shape\n", "min_step_shape = Coordinate(1, 1, 1)\n", "\n", "# The range of predicted values. We have a sigmoid activation on our model\n", "out_range = (0, 1)\n", "\n", "# How much to grow the input shape for prediction. This is usually adjusted to maximize GPU memory,\n", "# but depends on how you saved your model. The model we provided does not support different\n", "# input shapes.\n", "pred_size_growth = Coordinate(0, 0, 0)" ] }, { "cell_type": "code", "execution_count": null, "id": "4f958a69", "metadata": {}, "outputs": [], "source": [ "from volara.datasets import Affs, Raw\n", "from volara_torch.blockwise import Predict\n", "from volara_torch.models import TorchModel" ] }, { "cell_type": "markdown", "id": "e183fd2a", "metadata": {}, "source": [ "\n", "First we define the datasets that we are using along with some basic information about them" ] }, { "cell_type": "code", "execution_count": null, "id": "6a533e48", "metadata": {}, "outputs": [], "source": [ "# our raw data is stored in uint8, but our model expects floats in range (0, 1) so we scale it\n", "raw_dataset = Raw(store=\"sample_A+_20160601.zarr/raw\", scale_shift=(1 / 255, 0))\n", "# The affinities neighborhood depends on the model that was trained. 
{ "cell_type": "markdown", "id": "9dbc0721", "metadata": {}, "source": [ "Now we can define our model with the parameters we defined above. We will use the\n", "`TorchModel` class to load the model from a checkpoint and pass it to the `Predict` class." ] },
{ "cell_type": "code", "execution_count": null, "id": "99adc48a", "metadata": {}, "outputs": [], "source": [ "torch_model = TorchModel(\n", " save_path=\"checkpoint_data/model.pt\",\n", " checkpoint_file=\"checkpoint_data/model_checkpoint_15000\",\n", " in_channels=1,\n", " out_channels=out_channels,\n", " min_input_shape=min_input_shape,\n", " min_output_shape=min_output_shape,\n", " min_step_shape=min_step_shape,\n", " out_range=out_range,\n", " pred_size_growth=pred_size_growth,\n", ")\n", "predict_cremi = Predict(\n", " checkpoint=torch_model,\n", " in_data=raw_dataset,\n", " out_data=[lsds_dataset, affs_dataset],\n", ")\n", "\n", "predict_cremi.run_blockwise(multiprocessing=False)" ] },
{ "cell_type": "markdown", "id": "ba58113b", "metadata": {}, "source": [ "Let's visualize the results." ] },
{ "cell_type": "code", "execution_count": null, "id": "f3e7f22d", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import matplotlib.animation as animation\n", "\n", "\n", "fig, axes = plt.subplots(1, 3, figsize=(14, 8))\n", "\n", "ims = []\n", "for i, (raw_slice, affs_slice, lsd_slice) in enumerate(\n", " zip(\n", " raw_dataset.array(\"r\")[:],\n", " affs_dataset.array(\"r\")[:].transpose([1, 0, 2, 3]),\n", " lsds_dataset.array(\"r\")[:].transpose([1, 0, 2, 3]),\n", " )\n", "):\n", " # Keep the first frame static so the figure also renders without the animation;\n", " # every later frame is marked as animated\n", " animated = i > 0\n", " im_raw = axes[0].imshow(raw_slice, cmap=\"gray\", animated=animated)\n", " axes[0].set_title(\"Raw\")\n", " im_affs_long = axes[1].imshow(\n", " affs_slice[[0, 5, 6]].transpose([1, 2, 0]),\n", " vmin=0,\n", " vmax=255,\n", " interpolation=\"none\",\n", " animated=animated,\n", " )\n", " axes[1].set_title(\"Affs (0, 5, 6)\")\n", " im_lsd = axes[2].imshow(\n", " lsd_slice[:3].transpose([1, 2, 0]),\n", " vmin=0,\n", " vmax=255,\n", " interpolation=\"none\",\n", " animated=animated,\n", " )\n", " axes[2].set_title(\"LSDs (0, 1, 2)\")\n", " ims.append([im_raw, im_affs_long, im_lsd])\n", "\n", "ims = ims + ims[::-1]\n", "ani = animation.ArtistAnimation(fig, ims, blit=True)\n", "ani.save(\"_static/cremi/outputs.gif\", writer=\"pillow\", fps=10)\n", "plt.close()" ] },
{ "cell_type": "markdown", "id": "2e99808d", "metadata": {}, "source": [ "![predictions](_static/cremi/outputs.gif)" ] }
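, { "cell_type": "markdown", "id": "mws-note", "metadata": {}, "source": [ "Finally, the introduction promised a mutex watershed segmentation of the predicted affinities.\n", "Below is a minimal, in-memory sketch using the standalone `mwatershed` package (assumed to be\n", "installed) with a single uniform bias of 0.5; a real pipeline would run this blockwise and tune\n", "the bias per offset." ] },
{ "cell_type": "code", "execution_count": null, "id": "mws-sketch", "metadata": {}, "outputs": [], "source": [ "# Minimal mutex watershed sketch (assumes the `mwatershed` package is installed)\n", "import numpy as np\n", "\n", "import mwatershed\n", "\n", "# The affinities are stored as uint8 in [0, 255] (as assumed by the visualization above);\n", "# rescale to [0, 1] and subtract a bias of 0.5 so that positive values vote to merge and\n", "# negative values vote to split\n", "affs = affs_dataset.array(\"r\")[:].astype(np.float64) / 255.0\n", "offsets = [tuple(offset) for offset in affs_dataset.neighborhood]\n", "fragments = mwatershed.agglom(affs - 0.5, offsets=offsets)\n", "print(\"number of fragments:\", len(np.unique(fragments)))" ] }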
"cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" } }, "nbformat": 4, "nbformat_minor": 5 }