粒子画像のインスタンスセグメンテーションAI BPartIS を使ってみた

OpenCV

EM画像を対象にした粒子のインスタンスセグメンテーションモデルが下記論文で報告されていたので手持ちの画像で試してみました。

Bayesian Particle Instance Segmentation for Electron Microscopy Image QuantificationBatuhan Yildirim and Jacqueline M. ColeJournal of Chemical Information and Modeling 2021 61 (3), 1136-1149DOI: 10.1021/acs.jcim.0c01455

Just a moment...

提案モデルは、通常のインスタンスセグメンテーションモデルとは異なり、各画素が帰属する「粒子の重心座標」を予測するタスクに置き換えられている。これによってend-to-endでインスタンスセグメンテーションを実施できている(元ネタはこちらの論文:Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth, arXiv:1906.11109)。

予測タスクを抽象化したことで、同じ重みのモデルで複数種類の材料画像に適用できるとのこと。

また、推論時にもドロップアウトを使用することで予測値を分布として取得し、その分散から予測の信頼性を評価して、誤検出を抑制している。

著者公開コード

著者らがgithubで公開しているコードは以下。

GitHub - by256/bpartis: Bayesian Particle Instance Segmentation for Electron Microscopy Image Quantification (JCIM 2021).
Bayesian Particle Instance Segmentation for Electron Microscopy Image Quantification (JCIM 2021). - GitHub - by256/bpartis: Bayesian Particle Instance Segmentat...
GitHub - by256/imagedataextractor: ImageDataExtractor 2.0 - a Python library for electron microscopy image quantification.
ImageDataExtractor 2.0 - a Python library for electron microscopy image quantification. - GitHub - by256/imagedataextractor: ImageDataExtractor 2.0 - a Python l...

今回は上記コードをベースにして、使いやすいように一部修正して動かした。

ディレクトリ構成

ディレクトリ構成は以下の通り。コードの詳細はページ下部にまとめた。

.
└── bpartis
    ├── models
    |   └seg-model.pt
    └── segment
        ├── model.py
        ├── nnmodules.py
        ├── cluster.py
        ├── uncertainty.py
        └── visualization.py

学習済み重み(seg-model.pt)は、以下を使用。

https://github.com/by256/imagedataextractor/blob/master/imagedataextractor/models/seg-model.pt

実際に動かしてみた

ライブラリをインポート。

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

from bpartis.segment.model import ParticleSegmenter
from bpartis.segment.visualization import vis_segmap

 

テスト用画像(test2imageモデルで生成した画像)を読み込む。この画像に映っている白濁粒子をきれいにセグメンテーションできるか試してみる。ちなみに、モデルはEM画像(グレースケール画像)で学習されているが、テストにはカラー画像を適用している。

image = cv2.imread('./test.png')
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.show()

 

モデルを読み込み、セグメンテーションを実施、その結果をインスタンスごとに塗り分けて表示した。その結果、きれいにセグメンテーションできていることが分かる。

segmenter = ParticleSegmenter(model_path='../models/seg-model.pt', device='cpu')
segmentation, uncertainty, _ = segmenter.segment(image)

os.makedirs('./results', exist_ok=True)
cv2.imwrite('./results/test_seg.png', segmentation)

seg_cl, seg_cl_concat = vis_segmap(segmentation, image=image)

fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(2,1,1)
ax.imshow(seg_cl)
ax = fig.add_subplot(2,1,2)
ax.imshow(seg_cl_concat)
plt.show()

 

モデルの出力の確認。入力画像と同じ形をしたndarrayで、画素値がインスタンスごとのidとなっている。

print('shape:',segmentation.shape)
print('labels:', np.unique(segmentation))
shape: (512, 512)
labels: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31]

 

予測結果の分散から求めた、予測の不確かさを出力。粒子の領域は暗く(=不確かさが小さい)、粒子境界や、粒子が無い領域で明るく(=不確かさが大きい)なっていることが分かる。ベイズ推論の際は、この不確かさがしきい値以上の領域を除外するようにして誤検出を抑制している。

fig, ax = plt.subplots()
im = ax.imshow(uncertainty)
fig.colorbar(im, ax=ax)
plt.show()

 

GPUを使用する場合は、モデル読み込みの際の引数にdevice=’cuda’を入れる。

segmenter = ParticleSegmenter(model_path='../models/seg-model.pt',device='cuda')

 

また、ベイズ推論する必要が無い場合は、モデル読み込みの際の引数にbayesian=Falseを入れる。ベイズ推論時はドロップアウトを有効にして複数回(デフォルトでは30回)のフォワード処理をしているが、bayesian=Falseにするとフォワード処理が1回のみになるので処理時間が短くなるメリットがある。一方で、予測結果の不確かさを考慮できないため、誤検出が増えるデメリットがある。

今回のテスト画像でも、粒子の無い領域に誤検出がみられる。

segmenter = ParticleSegmenter(bayesian=False, model_path='../models/seg-model.pt',device='cuda')
segmentation, uncertainty, _ = segmenter.segment(image)

fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(2,1,1)
ax.imshow(seg_cl)
ax = fig.add_subplot(2,1,2)
ax.imshow(seg_cl_concat)
plt.show()

コード詳細

githubはこちら

GitHub - shashashanki/bpartis_test
Contribute to shashashanki/bpartis_test development by creating an account on GitHub.
import os
import torch
import numpy as np
from PIL import Image

from .cluster import Cluster
from .nnmodules import BranchedERFNet
from .uncertainty import expected_entropy, predictive_entropy, uncertainty_filtering


# model_pathを指定できるように追記
class ParticleSegmenter:

    def __init__(self, bayesian=True, n_samples=30, tu=0.0125, model_path='../models/seg-model.pt', device='cpu'):
        self.bayesian = bayesian
        self.n_samples = n_samples
        self.tu = tu
        self.seg_model = BranchedERFNet(num_classes=[4, 1]).to(device).eval()
        self.model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), model_path)
        self.seg_model.load_state_dict(torch.load(self.model_path, map_location=device))
        self.cluster = Cluster(n_sigma=2, device=device)
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        
    def preprocess(self, image):
        """Pre-process image for segmentation model."""
        image = Image.fromarray(image)
        image = image.resize((512, 512), resample=Image.BICUBIC)
        image = np.array(image)
        image = image / 255.0
        return image

    def postprocess_pred(self, image, h, w):
        """Post-process output segmentation map. Return output to the original input size."""
        image = Image.fromarray(image)
        image = image.resize((w, h), resample=Image.NEAREST)
        return np.array(image)

    def postprocess_uncertainty(self, image, h, w):
        """
        Resize uncertainty map. This is strictly for visualisation purposes.
        The output of this function will not be used for anything other 
        than visualisation.
        """
        image = Image.fromarray(image)
        image = image.resize((w, h), resample=Image.BICUBIC)
        return np.array(image)

    def enable_eval_dropout(self):
        """Enables dropout in eval mode for Bayesian inference via Monte Carlo dropout."""
        for module in self.seg_model.modules():
            if 'Dropout' in type(module).__name__:
                module.train()

    def monte_carlo_predict(self, image):
        """Performs Bayesian inference and computes epistemic uncertainty."""
        h, w = image.shape[-2:]
        cluster = Cluster(n_sigma=2, h=h, w=w, device=self.device)
        self.enable_eval_dropout()

        # get monte carlo model samples
        mc_outputs = []
        mc_seed_maps = []
        for i in range(self.n_samples):
            output = self.seg_model(image).detach()
            seed_map = torch.sigmoid(output[0, -1]).unsqueeze(0)  # \phi_{k}(e_{i})
            mc_outputs.append(output)
            mc_seed_maps.append(seed_map)

        mc_outputs = torch.cat(mc_outputs, dim=0)
        mc_seed_maps = torch.cat(mc_seed_maps, dim=0)

        # MC prediction (cluster the mean of MC samples)
        mc_prediction, _ = cluster.cluster(mc_outputs.mean(dim=0))

        # Uncertainty
        total = predictive_entropy(mc_seed_maps)
        aleatoric = expected_entropy(mc_seed_maps)
        epistemic = total - aleatoric  # $MI(y, \theta | x)$

        return mc_prediction, epistemic

    def segment(self, image):
        """Main segmentation routine."""
        o_h, o_w = image.shape[:2]
        image = self.preprocess(image)
        image = torch.FloatTensor(image).permute(2, 0, 1).unsqueeze(0).to(self.device)
        if self.bayesian:
            # monte carlo predict
            pred, uncertainty = self.monte_carlo_predict(image)
            original = pred.cpu().numpy().copy()
            original = self.postprocess_pred(original, o_h, o_w)
            pred = uncertainty_filtering(pred, uncertainty, tu=self.tu)
            pred = pred.cpu().numpy()
            uncertainty = uncertainty.cpu().numpy()
            # post-process uncertainty for visualisation
            uncertainty = self.postprocess_uncertainty(uncertainty, o_h, o_w)
        else:
            model_out = self.seg_model(image)[0].detach()
            pred = self.cluster.cluster(model_out)[0].cpu().numpy()
            uncertainty = None
            original = None
        pred = self.postprocess_pred(pred, o_h, o_w)
        return pred, uncertainty, original
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsamplerBlock (nn.Module):
    def __init__(self, ninput, noutput):
        super().__init__()

        self.conv = nn.Conv2d(ninput, noutput-ninput,
                              (3, 3), stride=2, padding=1, bias=True)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        output = torch.cat([self.conv(input), self.pool(input)], 1)
        output = self.bn(output)
        return F.elu(output)


class non_bottleneck_1d (nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(
            chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)

        self.conv1x3_1 = nn.Conv2d(
            chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)

        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(
            1*dilated, 0), bias=True, dilation=(dilated, 1))

        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(
            0, 1*dilated), bias=True, dilation=(1, dilated))

        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):

        output = self.conv3x1_1(input)
        output = F.elu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.elu(output)

        output = self.conv3x1_2(output)
        output = F.elu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        return F.elu(output+input)


class Encoder(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.initial_block = DownsamplerBlock(3, 16)

        self.layers = nn.ModuleList()

        self.layers.append(DownsamplerBlock(16, 64))

        for x in range(0, 5):  # 5 times
            self.layers.append(non_bottleneck_1d(64, 0.03, 1))

        self.layers.append(DownsamplerBlock(64, 128))

        for x in range(0, 2):  # 2 times
            self.layers.append(non_bottleneck_1d(128, 0.3, 2))
            self.layers.append(non_bottleneck_1d(128, 0.3, 4))
            self.layers.append(non_bottleneck_1d(128, 0.3, 8))
            self.layers.append(non_bottleneck_1d(128, 0.3, 16))

        # Only in encoder mode:
        self.output_conv = nn.Conv2d(
            128, num_classes, 1, stride=1, padding=0, bias=True)

    def forward(self, input, predict=False):
        output = self.initial_block(input)

        for layer in self.layers:
            output = layer(output)

        if predict:
            output = self.output_conv(output)

        return output


class UpsamplerBlock (nn.Module):
    def __init__(self, ninput, noutput):
        super().__init__()
        self.conv = nn.Conv2d(ninput, noutput, 3, stride=1, padding=1, bias=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3)

    def forward(self, input):
        output = F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False)
        output = self.conv(output)
        output = self.bn(output)
        return F.elu(output)


class Decoder (nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.layers = nn.ModuleList()

        self.layers.append(UpsamplerBlock(128, 64))
        self.layers.append(non_bottleneck_1d(64, 0, 1))
        self.layers.append(non_bottleneck_1d(64, 0, 1))

        self.layers.append(UpsamplerBlock(64, 16))
        self.layers.append(non_bottleneck_1d(16, 0, 1))
        self.layers.append(non_bottleneck_1d(16, 0, 1))

        self.output_conv = nn.Conv2d(16, num_classes, 3, stride=1, padding=1, bias=True)

    def forward(self, input):
        output = input

        for layer in self.layers:
            output = layer(output)

        output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=False)
        output = self.output_conv(output)

        return output


# ERFNet

class Net(nn.Module):
    def __init__(self, num_classes, encoder=None):  # use encoder to pass pretrained encoder
        super().__init__()

        if (encoder == None):
            self.encoder = Encoder(num_classes)
        else:
            self.encoder = encoder
        self.decoder = Decoder(num_classes)

    def forward(self, input, only_encode=False):
        if only_encode:
            return self.encoder.forward(input, predict=True)
        else:
            output = self.encoder(input)  # predict=False by default
            return self.decoder.forward(output)


class BranchedERFNet(nn.Module):
    def __init__(self, num_classes, encoder=None):
        super().__init__()

        if (encoder is None):
            self.encoder = Encoder(sum(num_classes))
        else:
            self.encoder = encoder

        self.decoders = nn.ModuleList()
        for n in num_classes:
            self.decoders.append(Decoder(n))

    def init_output(self, n_sigma=1):
        with torch.no_grad():
            output_conv = self.decoders[0].output_conv
            print('initialize last layer with size: ',
                  output_conv.weight.size())

            output_conv.weight[:, 0:2, :, :].fill_(0)
            output_conv.bias[0:2].fill_(0)

            output_conv.weight[:, 2:2+n_sigma, :, :].fill_(0)
            output_conv.bias[2:2+n_sigma].fill_(1)

    def forward(self, input, only_encode=False):
        if only_encode:
            return self.encoder.forward(input, predict=True)
        else:
            output = self.encoder(input)

        return torch.cat([decoder.forward(output) for decoder in self.decoders], 1)
import torch


class Cluster:

    def __init__(self, n_sigma=2, h=512, w=512, device='cuda'):
        self.n_sigma = n_sigma
        self.device = device
        xm = torch.linspace(0, 1, w).view(1, 1, -1).expand(1, h, w)
        ym = torch.linspace(0, 1, h).view(1, -1, 1).expand(1, h, w)
        xym = torch.cat((xm, ym), 0)
        self.xym = xym.to(self.device)

    def cluster(self, prediction, threshold=0.5):

        height, width = prediction.size(1), prediction.size(2)
        xym_s = self.xym[:, 0:height, 0:width]
        
        spatial_emb = torch.tanh(prediction[0:2]) + xym_s  # 2 x h x w
        sigma = prediction[2:2+self.n_sigma]  # n_sigma x h x w
        seed_map = torch.sigmoid(prediction[2+self.n_sigma:2+self.n_sigma + 1])  # 1 x h x w
       
        instance_map = torch.zeros(height, width).short()
        instances = []

        count = 1
        mask = (seed_map > 0.5).bool()

        if mask.sum() > 128:

            spatial_emb_masked = spatial_emb[mask.expand_as(spatial_emb)].view(2, -1)
            sigma_masked = sigma[mask.expand_as(sigma)].view(self.n_sigma, -1)
            seed_map_masked = seed_map[mask].view(1, -1)

            unclustered = torch.ones(mask.sum()).short().to(self.device)
            instance_map_masked = torch.zeros(mask.sum()).short().to(self.device)

            while(unclustered.sum() > 128):

                seed = (seed_map_masked * unclustered.float()).argmax().item()
                seed_score = (seed_map_masked * unclustered.float()).max().item()
                if seed_score < threshold:
                    break
                center = spatial_emb_masked[:, seed:seed+1]
                unclustered[seed] = 0
                s = torch.exp(sigma_masked[:, seed:seed+1]*10)
                dist = torch.exp(-1*torch.sum(torch.pow(spatial_emb_masked -
                                                        center, 2)*s, 0, keepdim=True))

                proposal = (dist > 0.5).squeeze()

                if proposal.sum() > 128:
                    if unclustered[proposal].sum().float()/proposal.sum().float() > 0.5:
                        instance_map_masked[proposal.squeeze()] = count
                        instance_mask = torch.zeros(height, width).short()
                        instance_mask[mask.squeeze().cpu()] = proposal.cpu().short()
                        instances.append(
                            {'mask': instance_mask.squeeze()*255, 'score': seed_score})
                        count += 1

                unclustered[proposal] = 0

            instance_map[mask.squeeze().cpu()] = instance_map_masked.cpu()

        return instance_map, instances
import torch


def entropy(p, eps=1e-6):
    p = torch.clamp(p, eps, 1.0-eps)
    return -1.0*((p*torch.log(p)) + ((1.0-p)*(torch.log(1.0-p))))

def expected_entropy(mc_preds):
    """Aleatoric (data) uncertainty"""
    return torch.mean(entropy(mc_preds), dim=0)

def predictive_entropy(mc_preds):
    """Total uncertainty"""
    return entropy(torch.mean(mc_preds, dim=0))

def uncertainty_filtering(prediction, uncertainty, tu=0.0125):
    """Filters instance segmentaton predictions based on their uncertainty."""

    filtered_pred = torch.zeros_like(prediction)

    for inst_id in torch.unique(prediction):
        if inst_id == 0:
            continue
        inst_mask = prediction == inst_id
        inst_uncertainty = torch.mean(uncertainty[inst_mask])
        if inst_uncertainty < tu:
            filtered_pred[inst_mask] = torch.max(filtered_pred) + 1

    return filtered_pred

予測結果を描画するメソッドを新たに作成。

import numpy as np
import cv2
import matplotlib.pyplot as plt


def vis_segmap(segmentation, cm_name='gist_ncar', image=[]):
    cm = plt.get_cmap(cm_name)

    x,y = segmentation.shape
    seg_cl = np.zeros((x,y,3))

    for i in list(np.unique(segmentation)):
        # 背景はスキップ
        if i==0:
            continue

        cl_num = i/23
        cl_num = cl_num - int(cl_num)

        seg_cl[segmentation==i] = cm(cl_num)[:3]

    if image==[]:
        return seg_cl
    else:
        seg_cl_concat = (cv2.cvtColor(image, cv2.COLOR_BGR2RGB)/255 + seg_cl)*0.5
        return seg_cl, seg_cl_concat

まとめ

今回は、粒子画像のインスタンスセグメンテーションモデルであるBPartISを試した。未学習のテスト画像でもしっかりセグメンテーションできることが確認できた。

コメント

タイトルとURLをコピーしました