Cremi example

This example shows how to use volara to predict local shape descriptors (LSDs) and affinities for the CREMI dataset, and then run mutex watershed on the predicted affinities.

[1]:
from pathlib import Path

import h5py
import wget
import zarr
from funlib.geometry import Coordinate

# Output directory for the animation saved at the end of this example
Path("_static/cremi").mkdir(parents=True, exist_ok=True)

hdf_path = "sample_A+_20160601.hdf"
zarr_path = "sample_A+_20160601.zarr"

# Download some CREMI data and immediately convert it to zarr for convenience.
# Each step is skipped if it has already been done.
if not Path(f"{zarr_path}/raw").exists():
    if not Path(hdf_path).exists():
        wget.download("https://cremi.org/static/data/sample_A+_20160601.hdf", hdf_path)
    with h5py.File(hdf_path, "r") as f:
        raw_ds = zarr.open(zarr_path, "w").create_dataset(
            "raw", data=f["volumes/raw"][:]
        )
    # Metadata describing the voxel size, axis names and units of the array
    raw_ds.attrs["voxel_size"] = (40, 4, 4)
    raw_ds.attrs["axis_names"] = ["z", "y", "x"]
    raw_ds.attrs["unit"] = ["nm", "nm", "nm"]

Now we can predict the LSDs and affinities for this dataset. We have provided a very simple pretrained model for this dataset. We favored speed and efficiency over accuracy so that the model can run in a GitHub Action. With access to a GPU and more memory, you can train a significantly better model.

[2]:
# Here are some important details about the model:

# The number of output channels of our model: 10 LSD channels and 7 affinity channels
out_channels = [10, 7]

# The input shape of our model (not including channels)
min_input_shape = Coordinate(36, 252, 252)

# The output shape of our model (not including channels)
min_output_shape = Coordinate(32, 160, 160)

# The minimum increment for adjusting the input shape
min_step_shape = Coordinate(1, 1, 1)

# The range of predicted values. Our model ends with a sigmoid activation
out_range = (0, 1)

# How much to grow the input shape for prediction. This is usually increased to make
# full use of the available GPU memory, but depends on how you saved your model. The
# model we provided does not support different input shapes, so we leave it at zero.
pred_size_growth = Coordinate(0, 0, 0)

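To see what these shapes imply, here is a small sketch of the arithmetic. It assumes the network uses valid (unpadded) convolutions, so the difference between input and output shape is context that is split evenly between both sides of a block. The volume shape `(125, 1250, 1250)` is an assumption (we believe it matches the CREMI test volumes, but it is not read from the file here); under that assumption, tiling the volume with non-overlapping output blocks would take 256 blocks, which would be consistent with the execution summary of the prediction step.

```python
import math

# Shapes copied from the cell above as plain (z, y, x) tuples, in voxels
min_input_shape = (36, 252, 252)
min_output_shape = (32, 160, 160)
voxel_size = (40, 4, 4)  # nm, from the raw dataset attributes

# Context consumed by the network in total, and on each side of a block
# (assumes valid convolutions with symmetric context)
context = tuple(i - o for i, o in zip(min_input_shape, min_output_shape))
context_per_side = tuple(c // 2 for c in context)

# Size of one output block in world units (nm)
output_block_nm = tuple(o * v for o, v in zip(min_output_shape, voxel_size))

# Hypothetical volume shape (z, y, x) in voxels; an assumption, not read from the file
volume_shape = (125, 1250, 1250)

# Blocks needed to cover the volume, rounding up so partial edge blocks count
blocks_per_axis = tuple(
    math.ceil(v / o) for v, o in zip(volume_shape, min_output_shape)
)
num_blocks = math.prod(blocks_per_axis)

print(context, context_per_side, output_block_nm, blocks_per_axis, num_blocks)
# prints: (4, 92, 92) (2, 46, 46) (1280, 640, 640) (4, 8, 8) 256
```

Note how anisotropic the block is in voxels yet how much less so in nanometers, because z sections are 10 times thicker than the xy pixel size.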
[3]:
from volara.datasets import Affs, Raw
from volara_torch.blockwise import Predict
from volara_torch.models import TorchModel

First, we define the datasets we are using, along with some basic information about them.

[4]:
# Our raw data is stored as uint8, but our model expects floats in the range (0, 1), so we scale it by 1/255
raw_dataset = Raw(store="sample_A+_20160601.zarr/raw", scale_shift=(1 / 255, 0))
# The affinity neighborhood depends on how the model was trained. Ours predicts direct-neighbor
# affinities in z, y and x, plus long-range affinities at 6 and 18 voxels in y and x
affs_dataset = Affs(
    store="sample_A+_20160601.zarr/affs",
    neighborhood=[
        Coordinate(1, 0, 0),
        Coordinate(0, 1, 0),
        Coordinate(0, 0, 1),
        Coordinate(0, 6, 0),
        Coordinate(0, 0, 6),
        Coordinate(0, 18, 0),
        Coordinate(0, 0, 18),
    ],
)
# We store the LSDs in a plain zarr dataset, using the same format as the raw data
lsds_dataset = Raw(store="sample_A+_20160601.zarr/lsds")
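As a rough illustration of what the `neighborhood` means, the sketch below computes ground-truth affinities from a label volume using the same seven offsets. It is not volara's code: comparing voxel `p` with voxel `p - offset`, and treating label 0 as background that never has affinity, are assumptions made for this example. Channel `c` of the result is 1 where both voxels carry the same non-zero label.

```python
import numpy as np

# Neighborhood copied from the Affs dataset above, as (z, y, x) offsets
neighborhood = [
    (1, 0, 0), (0, 1, 0), (0, 0, 1),
    (0, 6, 0), (0, 0, 6),
    (0, 18, 0), (0, 0, 18),
]


def affinities_from_labels(labels, neighborhood):
    """Return a (len(neighborhood), z, y, x) uint8 array of affinities.

    Channel c is 1 at voxel p if p and p - offset have the same non-zero
    label (an assumed convention), and 0 otherwise, including where
    p - offset falls outside the volume.
    """
    affs = np.zeros((len(neighborhood),) + labels.shape, dtype=np.uint8)
    for c, offset in enumerate(neighborhood):
        # Voxels p whose neighbor p - offset lies inside the volume
        src = tuple(slice(o, None) for o in offset)
        # The corresponding neighbors p - offset
        dst = tuple(slice(0, s - o) for o, s in zip(offset, labels.shape))
        a, b = labels[src], labels[dst]
        affs[(c,) + src] = (a == b) & (a != 0)
    return affs


# Usage: two objects split at x = 10, two sections deep
labels = np.zeros((2, 20, 20), dtype=np.uint64)
labels[:, :, :10] = 1
labels[:, :, 10:] = 2
affs = affinities_from_labels(labels, neighborhood)
print(affs.shape)  # (7, 2, 20, 20)
```

Long-range offsets let the model "see" boundaries from further away, which tends to make a subsequent watershed more robust to small gaps in the short-range affinities.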

Now we can define our model using the parameters above. We use the TorchModel class to load the model from a checkpoint and pass it to the Predict class, which runs the model blockwise over the raw data and writes both the LSDs and the affinities.

[5]:
torch_model = TorchModel(
    save_path="checkpoint_data/model.pt",
    checkpoint_file="checkpoint_data/model_checkpoint_15000",
    in_channels=1,
    out_channels=out_channels,
    min_input_shape=min_input_shape,
    min_output_shape=min_output_shape,
    min_step_shape=min_step_shape,
    out_range=out_range,
    pred_size_growth=pred_size_growth,
)
predict_cremi = Predict(
    checkpoint=torch_model,
    in_data=raw_dataset,
    out_data=[lsds_dataset, affs_dataset],
)

predict_cremi.run_blockwise(multiprocessing=False)
Starting prediction...

Execution Summary
-----------------

  Task lsds-affs-predict:

    num blocks : 256
    completed ✔: 256 (skipped 0)
    failed    ✗: 0
    orphaned  ∅: 0

    all blocks processed successfully
[5]:
defaultdict(daisy.task_state.TaskState,
            {'lsds-affs-predict': Started: True
             Total Blocks: 256
             Ready: 0
             Processing: 0
             Pending: 0
             Completed: 256
             Skipped: 0
             Failed: 0
             Orphaned: 0})
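`run_blockwise` splits the volume into blocks and predicts each one independently. The following NumPy sketch shows the general idea under simplifying assumptions: `toy_model` and `predict_blockwise` are hypothetical names, the "model" merely crops its input the way a network with valid convolutions shrinks it, and the volume is zero-padded so edge blocks still receive full context. It does not reproduce how volara or daisy actually schedule, pad, or write blocks.

```python
import itertools

import numpy as np


def toy_model(block, context_per_side):
    # Hypothetical stand-in for a network with valid convolutions:
    # it crops `context_per_side` voxels from every face of its input
    crop = tuple(slice(c, s - c) for c, s in zip(context_per_side, block.shape))
    return block[crop]


def predict_blockwise(volume, output_shape, context_per_side, model):
    out = np.zeros_like(volume)
    # Zero-pad so blocks at the volume boundary still get full context
    padded = np.pad(volume, [(c, c) for c in context_per_side], mode="constant")
    starts = [range(0, s, o) for s, o in zip(volume.shape, output_shape)]
    for start in itertools.product(*starts):
        # Output region of this block, clipped to the volume
        stop = tuple(
            min(b + o, s) for b, o, s in zip(start, output_shape, volume.shape)
        )
        # Input region = output region grown by the context on each side;
        # in padded coordinates [b - c, e + c) becomes [b, e + 2c)
        in_slices = tuple(
            slice(b, e + 2 * c) for b, e, c in zip(start, stop, context_per_side)
        )
        out[tuple(slice(b, e) for b, e in zip(start, stop))] = model(
            padded[in_slices], context_per_side
        )
    return out


# Usage: the toy model reproduces its input, so blockwise output == input
volume = np.random.default_rng(0).random((5, 17, 19))
result = predict_blockwise(volume, (2, 8, 8), (1, 3, 3), toy_model)
print(np.allclose(result, volume))  # True
```

Because each block reads only its own input region, blocks can be processed in any order or in parallel, which is what allows blockwise tools such as daisy to spread the work across workers.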

Let’s visualize the results by stepping through the z-slices and saving them as an animation.

[6]:
import matplotlib.animation as animation
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(14, 8))
# Titles do not change between frames, so set them once
axes[0].set_title("Raw")
axes[1].set_title("Affs (0, 5, 6)")
axes[2].set_title("LSDs (0, 1, 2)")

ims = []
# Step through z-slices. Affs and LSDs are stored channels-first (c, z, y, x),
# so move z to the front to iterate over slices.
for raw_slice, affs_slice, lsd_slice in zip(
    raw_dataset.array("r")[:],
    affs_dataset.array("r")[:].transpose([1, 0, 2, 3]),
    lsds_dataset.array("r")[:].transpose([1, 0, 2, 3]),
):
    im_raw = axes[0].imshow(raw_slice, cmap="gray", animated=True)
    # Show the z-affinity (0) and the longest-range y and x affinities (5, 6) as RGB
    im_affs_long = axes[1].imshow(
        affs_slice[[0, 5, 6]].transpose([1, 2, 0]),
        vmin=0,
        vmax=255,
        interpolation="none",
        animated=True,
    )
    # Show the first three LSD channels as RGB
    im_lsd = axes[2].imshow(
        lsd_slice[:3].transpose([1, 2, 0]),
        vmin=0,
        vmax=255,
        interpolation="none",
        animated=True,
    )
    ims.append([im_raw, im_affs_long, im_lsd])

# Play forwards and then backwards so the GIF loops smoothly
ims = ims + ims[::-1]
ani = animation.ArtistAnimation(fig, ims, blit=True)
ani.save("_static/cremi/outputs.gif", writer="pillow", fps=10)
plt.close(fig)

[Figure: predictions (panels: Raw, Affs (0, 5, 6), LSDs (0, 1, 2))]