A robust deep learning approach for segmenting cortical and trabecular bone from 3D high resolution µCT scans of mouse bone



Datasets

We evaluated the effectiveness of our deep learning segmentation architecture, DBAHNet, across various experimental studies (see Fig. 2). Our extensive dataset contains a total of 163 tibia scans derived from seven experimental studies14,15,16,37,38,39,40. These scans exhibit varied bone morphology due to differences in scanning resolutions, mouse strains, ages, drug treatments, surgical procedures, and mechanical loading. The dataset includes four mouse strains: C57BL/6, BALB/c, C57BL/6JOlaHsd, and homozygous oim, focusing on young and mature animals ranging from 8 to 24 weeks of age. These animals received a variety of treatments, including ovariectomy (OVX), human amniotic fluid stem cells (hAFSC), sciatic neurectomy (SN), risedronate (Ris), and parathyroid hormone (PTH) treatments at different doses. Additionally, some studies have applied mechanical loading (ML) to investigate the individual and combined effects of these treatments on bone structure. The mouse tibiae were imaged via µCT at different resolutions ranging from 4.8 µm to 13.7 µm. This high resolution enabled a detailed assessment of trabecular and cortical bone structures. The dataset covered various responses to drug interventions, with mechanical loading experiments designed to mimic physiological stress and explore bone adaptation responses.

Fig. 2

Large, diversified collection of high-resolution µCT scans at the proximal region of interest of the mouse tibia, covering multiple experimental setups14,15,16,37,38,39,40. These setups include different treatments and conditions, such as parathyroid hormone (PTH), risedronate (Ris), sciatic neurectomy (SN), ovariectomy (OVX), human amniotic fluid stem cells (hAFSC), mechanical loading (ML), age (8 to 24 weeks), mouse strain (C57BL/6, BALB/c, and homozygous oim), and scanning resolution (4.8 µm to 13.7 µm).

We manually segmented the scans following standard guidelines10 by sampling sectional 2D slices with a fixed step tailored to the specific bone region. This involved manual segmentation of both the cortical and trabecular compartments at these specific cross-sectional slices. We subsequently employed the Biomedisa platform11 for semiautomated segmentation, which interpolates the presegmented slices via weighted random walks while also considering the entire original volumetric image data. We performed postprocessing on the interpolated 3D labels to smooth and remove outliers, followed by visual inspection to validate the final ground truth labels.

The datasets used in this study are described as follows:

  • Control dataset14,15,16 This dataset is based on three separate studies and includes tibiae from C57BL/6 virgin female mice aged 19–22 weeks. It comprises 74 control tibiae that were not subjected to any treatments in the referenced preclinical experiments. High-resolution µCT scans were performed via SkyScan 1172 (SkyScan, Kontich, Belgium), with resolutions ranging from 4.8 µm to 5 µm (ex vivo).

  • Dataset 114 This study investigated the impact of the bone anabolic drug intermittent PTH on bone adaptation in virgin female C57BL/6 mice. The treatment doses used were 20, 40, and 80 µg/kg/day, both alone and in combination with ML. The dataset includes images of four groups: PTH 20 (N=6), PTH 20 + ML (N=6), PTH 40 (N=8), and PTH 80 (N=10), all aged 19 weeks. Images were captured at a resolution of 5 µm (ex vivo). Both mechanical loading and PTH treatment have anabolic effects on bone, promoting bone formation and increasing bone mass. Their combined effects result in more pronounced anabolic responses, further complicating segmentation due to increased bone remodeling and porosity, especially near the growth plate.

  • Dataset 215 This study examined the effects of the anticatabolic drug risedronate on bone adaptation in virgin female C57BL/6 mice. The dataset includes three risedronate dose groups (0.15, 1.5, and 15 µg/kg/day) with and without mechanical loading, each with \(N=5\) samples, and a risedronate 150 µg/kg/day group with one loaded and one nonloaded sample (\(N=1\) each), all aged 19 weeks. Images were captured at a resolution of 4.8 µm (ex vivo). This segmentation is challenging because of the combined effects of the anticatabolic drug risedronate and the anabolic effect of ML. Compared with the control, risedronate reduces bone resorption, resulting in greater trabecular bone volume and trabecular number, whereas ML increases both trabecular and cortical bone mass.

  • Dataset 316 This study assessed the impact of mechanical loading on bone adaptation in C57BL/6 mice subjected to right sciatic neurectomy to minimize natural loading in their right tibiae. The dataset includes images of two groups, 4 N (N=5) and 8 N (N=5), aged 20 weeks. Images were captured at a resolution of 5 µm (ex vivo). The segmentation challenges arise from localized bone loss due to neurectomy and the subsequent anabolic bone changes induced by mechanical loading.

  • Dataset 437 This study provides high-resolution in vivo µCT images of tibiae from female C57BL/6 mice subjected to OVX, which mimics postmenopausal osteoporosis characterized by increased bone remodeling, followed by combined PTH (100 µg/kg/day) and ML interventions. The dataset includes wild-type (WT) female C57BL/6 OVX (N=4) mice recorded at weeks 14, 18, 20, and 22. Images were captured at a resolution of 10.4 µm (in vivo). OVX increases porosity and bone remodeling, presenting significant segmentation challenges. The combination of PTH and ML further complicates segmentation due to their anabolic effects, enhancing bone formation and altering bone architecture. Additionally, the changes in the resolution and age of the mice compared with those in the control dataset complicate the generalization of segmentation techniques.

  • Dataset 538 This study conducted high-resolution µCT analysis of bone microarchitecture in 8-week-old homozygous oim mice treated with human amniotic fluid stem cells (hAFSC). The dataset includes images of the Oim (N=3) and Oim + hAFSC (N=3) groups. Images were captured at a resolution of 5 µm (ex vivo). Osteogenesis imperfecta (OI), modeled by homozygous oim mice, is characterized by reduced size, skeletal fragility, frequent fractures, and abnormal bone microarchitecture. Treatment with hAFSC improved bone strength, quality, and remodeling. The young age of the mice, combined with their deformed bone shape due to the nature of the mouse strain (homozygous oim) and hAFSC treatment effects, presents significant segmentation challenges. Their bones are not fully mature and are less dense, complicating the generalization of segmentation techniques from the control dataset of untreated mature bones.

  • Dataset 639 This study explored the impact of ovariectomy on bone structure and density in female C57BL/6 and BALB/c mice by comparing the WT and OVX groups. The dataset includes four groups: C57BL/6 WT (N=1), C57BL/6 OVX (N=1), BALB/c WT (N=1), and BALB/c OVX (N=1). Images were captured at a resolution of 10.4 µm (in vivo) at the age of 24 weeks. The differences in the structure of the strains (C57BL/6 and BALB/c), OVX bone loss effects, and lower resolution make it difficult to generalize segmentation techniques from the control dataset.

  • Dataset 740 This study focused on a murine model of osteoporosis in C57BL/6JOlaHsd OVX female mice. The dataset includes images of femurs from C57BL/6JOlaHsd female mice (N=4) that underwent OVX at the age of 14 weeks. Images were captured at a resolution of 13.7 µm (ex vivo) at the age of 17 weeks. The combination of femur bones, whose structural characteristics differ from those of tibiae, the much lower resolution of 13.7 µm, and OVX-induced bone loss presents substantial segmentation challenges.

This diverse dataset encompasses a wide range of conditions, including different ages, strains, resolutions, drug treatments, surgical procedures, and mechanical loading, providing a rich resource for a robust validation of our deep learning model. A summary of the datasets used in this study is presented in Table 7. The main datasets14,15,16 are our own extensive collections, which contain very high-resolution µCT scans at 5 µm. The secondary datasets37,38,39,40 consist of publicly available samples collected to test the segmentation under new unseen experimental conditions.

The datasets used in this study were obtained from independent experiments conducted by their respective institutions. For Dataset 114 and Dataset 215, all procedures complied with the UK Animals (Scientific Procedures) Act 1986, with ethical approval from the ethics committee of The Royal Veterinary College (London, UK). Dataset 316 was similarly approved by the ethics committee of the University of Bristol (Bristol, UK). Dataset 437 followed the ARRIVE guidelines and was approved by the local Research Ethics Committee of the University of Sheffield (Sheffield, UK). Dataset 538 was conducted under UK Home Office project licence PPL 70/6857, and Dataset 639 under project licence PPL 40/3499, both overseen by the University of Sheffield. Finally, Dataset 740 received approval from the Local Ethical Committee for Animal Research of the University of the Basque Country (UPV/EHU, ref M20/2019/176), adhering to European Directive 2010/63/EU and Spanish Law RD 53/2013. All original studies ensured compliance with relevant ethical guidelines, and our use of these datasets strictly followed their established approvals.

Table 7 Summary of the datasets used in this study. The control dataset is listed first (Underline), followed by the main datasets (Bold) and secondary datasets (Italic). Abbreviations: parathyroid hormone (PTH), mechanical loading (ML), risedronate (Ris), sciatic neurectomy (SN), ovariectomy (OVX), human amniotic fluid stem cells (hAFSC), wild type (WT). \(N\) represents the number of µCT mouse tibia scans.

The region of interest in the mouse tibia used in this research was cropped from the metaphysis, starting just below the growth plate (approximately 6–8% of the bone length from the proximal region, where trabecular bone is highly present and active) and extending to approximately 60–65% of the bone length. Additionally, the control scans were obtained from a slightly deeper region within the metaphysis, where trabecular bone is not excessively present in the medullary area. The reason for choosing this deeper region is to detect any potential trabecularization of the cortical bone or growth of the trabecular bone, which can occur under certain conditions such as with drug treatments or aging.
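As a concrete illustration, the fractional bounds above can be mapped to z-slice indices along the scan axis. The function below is a hypothetical sketch: the 7% and 62% defaults are our illustrative midpoints of the quoted 6–8% and 60–65% ranges, not values from the published pipeline.

```python
def roi_slice_range(n_slices, start_frac=0.07, end_frac=0.62):
    """Convert fractional bone-length bounds into z-slice indices.

    n_slices: number of slices spanning the full bone length.
    Returns (start, end) indices for cropping, e.g. volume[start:end].
    """
    start = int(round(n_slices * start_frac))
    end = int(round(n_slices * end_frac))
    return start, end

# Example: a tibia scan spanning 3400 slices along the bone axis.
start, end = roi_slice_range(3400)
# roi = volume[start:end]   # crop the metaphyseal region of interest
```

In practice the exact bounds depend on locating the growth plate in each scan, which these fixed fractions only approximate.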

General segmentation pipeline

This section describes our automated, robust, deep learning-based pipeline for segmenting 3D high-resolution µCT scans, which specifically targets the cortical and trabecular compartments of the mouse tibia, as illustrated in Fig. 3. The general segmentation pipeline begins with preprocessing the raw 3D µCT scans via image processing techniques to isolate the mouse tibia. These preprocessed scans serve as the input for training the deep learning model. During training, data augmentation automatically expands the dataset, creating variations that improve model accuracy. The model is trained iteratively on the training set, with continuous monitoring of the validation mean Dice score and loss until convergence is achieved. After training, the model produces segmentation masks for both the cortical and the trabecular compartments. A postprocessing step further refines the segmentation to enhance the extraction of the cortical and trabecular bone. The detailed steps of the pipeline are outlined below.

Fig. 3

The global pipeline for robust 3D high-resolution µCT mouse tibia segmentation. The pipeline includes preprocessing, data augmentation, segmentation via DBAHNet, and final postprocessing.

Preprocessing

We subjected the raw 3D µCT scans to a series of preprocessing steps to prepare the data for segmentation. Thresholding: We applied the Otsu thresholding algorithm43 to automatically separate the bone from the background, which includes experimental materials such as the sample holder and the resin. To ensure the retention of actual bone voxels, particularly for the trabecular bone, a threshold margin of \(M = 5\) was subtracted from the threshold obtained by the algorithm to maintain connectivity. Artifact removal: We eliminated any remaining noise by retaining the largest connected component, which represents the bone. These two preprocessing steps are crucial, as they not only clean the bone from the experimental background, allowing the model to focus on segmenting the cortical and trabecular compartments, but also significantly reduce the size of the µCT scans. Working with 3D µCT scans at very high resolution requires careful consideration of efficiency, as training complex deep learning architectures becomes computationally demanding with larger input data. Performing these steps substantially reduces the size of the input images. For instance, a raw scan of the full mouse tibia recorded at 5 µm from Dataset 114 is approximately 2.4 GB. After background removal and autocropping, the file size is reduced to approximately 150 MB (both sizes are reported for the compressed NIfTI format). Fibula removal: We removed the second-largest component, representing the fibula, at each cross-sectional slice along the \(z\)-axis. Normalization: We normalized the voxel values via z-score normalization, transforming the image intensities so that the resulting distribution has a mean of zero and a standard deviation of one.
The z-score normalization is defined as \(Z = \frac{X - \mu }{\sigma }\), where \(Z\) is the normalized intensity value, \(X\) is the original intensity value, \(\mu\) is the mean intensity value of the image, and \(\sigma\) is the standard deviation of the intensity values of the image.

Data augmentation

To enhance the model’s generalization ability, we employed various data augmentation techniques applied to the original 3D scans during each batch generation throughout the training. Random affine transformations: We applied rotations and scaling to simulate changes in the orientation and scale of the bone relative to the scanner. The rotation range is \([0, \pi ]\) along the \(z\)-axis, and the scaling factor range is \(s \in [0.85, 1.25]\). 3D elastic deformations: We introduced nonlinear distortions to mimic natural bone variability via the following formula: \(x’ = x + \alpha \cdot {\mathcal {G}}(\sigma )\), where \({\mathcal {G}}(\sigma )\) is a random Gaussian displacement field with a standard deviation \(\sigma \in [9, 13]\) and magnitude \(\alpha \in [0, 900]\). Random Gaussian noise: We added random Gaussian noise to simulate varying scanner qualities. The noise addition is given by \(x’ = x + {\mathcal {N}}(0, \sigma ^2)\), where \({\mathcal {N}}(0, \sigma ^2)\) is Gaussian noise with zero mean and variance \(\sigma ^2 = 0.1\). Random intensity scaling: We scaled the intensity of the images to account for differences in imaging conditions. The intensity scaling is given by \(x’ = x \cdot (1 + f)\), where the scaling factor \(f\) ranges from \(-0.1\) to \(0.1\). Random contrast adjustment: We adjusted the contrast of the images to account for differences in imaging conditions. The contrast adjustment is expressed as \(x’ = x^{\gamma }\) with \(\gamma \in [0.5, 4.5]\). These transformations ensure the robustness and accuracy of the deep learning model by providing diverse and realistic variations in the training data. This approach generates new, artificially augmented data during training, where data augmentation is applied live to each batch with a small probability (\(p = 0.1\)), simulating scans under different experimental setups for the training of our deep learning model.
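A minimal sketch of the per-batch augmentation logic is shown below, using the parameter ranges quoted above. Only the intensity-level transforms are implemented; the affine and elastic deformations are omitted for brevity (in practice a library such as MONAI or TorchIO would supply them). This is an illustration of the sampling scheme, not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(42)

def augment(vol, p=0.1):
    """Apply each augmentation independently with probability p (live, per batch)."""
    out = vol.astype(float)
    if rng.random() < p:  # additive Gaussian noise, variance sigma^2 = 0.1
        out = out + rng.normal(0.0, np.sqrt(0.1), out.shape)
    if rng.random() < p:  # intensity scaling, f in [-0.1, 0.1]
        out = out * (1.0 + rng.uniform(-0.1, 0.1))
    if rng.random() < p:  # gamma contrast adjustment, gamma in [0.5, 4.5]
        gamma = rng.uniform(0.5, 4.5)
        lo, hi = out.min(), out.max()
        out = ((out - lo) / (hi - lo + 1e-8)) ** gamma * (hi - lo) + lo
    return out
```

With \(p = 0.1\), most batches pass through unchanged, so the model still sees the original intensity statistics most of the time.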

Segmentation

In this study, we employed a novel deep neural network architecture, DBAHNet (dual branch attention-based hybrid network), which was previously validated by comparing its performance with popular state-of-the-art architectures on the control dataset36. DBAHNet is specifically designed for high-resolution 3D µCT bone image segmentation, and focuses on the cortical and trabecular compartments. This architecture advances deep learning approaches by integrating both transformers and convolutional neural networks to effectively capture local features and long-range dependencies. The hybrid design of DBAHNet leverages the ability of convolutional layers for local feature analysis and the attention mechanism of transformers. In this work, we apply DBAHNet within a comprehensive pipeline to evaluate its robustness across various conditions and datasets, demonstrating its utility beyond the initial conference presentation. The complete architecture of DBAHNet is detailed in the subsequent sections.

Postprocessing

The final phase involved applying postprocessing techniques to improve the quality of the segmentation masks and mitigate the inherent imperfections in the segmentation process: Noise removal: We removed any segmentation noise and outliers by retaining the largest connected component. Transitional region smoothing: We used morphological opening filters to remove small openings at the endosteum surface of the cortical bone and assign them to the trabecular bone. The morphological opening filter is defined as: \(\text {Opening}(A, B) = (A \ominus B) \oplus B\), where \(A\) is the set of foreground voxels in the binary image, \(B\) is the structuring element (a sphere with radius \(K_o\)), \(\ominus\) denotes the erosion filter, which removes pixels from the boundaries of objects, eliminating small openings at the endosteum surface, and \(\oplus\) denotes the dilation filter, which adds pixels to the boundaries of objects, restoring the original size of the cortical surface while maintaining a smooth transition to the trabecular bone. The kernel value \(K_o\) is set to 3. Trabecular structure connectivity: We ensured the connectivity of the trabeculae for accurate morphometry in subsequent steps. For this, we performed connected component analysis by identifying and labeling all connected components in the binary mask of the trabecular bone and merging components that are close to each other. Merging is performed via a morphological closing filter with a kernel radius \(R_c = 1\), corresponding to the minimum distance required to merge disconnected trabeculae.
The morphological closing filter can be defined as follows: \(\text {Closing}(A, B) = (A \oplus B) \ominus B\), where \(A\) is the set of foreground voxels in the binary image, \(B\) is the structuring element (a sphere with radius \(R_c\)), \(\oplus\) denotes the dilation filter, which adds pixels to the boundaries of objects, potentially bridging small gaps caused by segmentation errors, and \(\ominus\) denotes the erosion filter, which removes pixels from the boundaries of objects, and restores the original object size while maintaining new connections. The different modules of the general segmentation pipeline facilitated the extraction and subsequent morphological analysis of both cortical and trabecular bone from three-dimensional µCT scans, enabling their visualization and assessment of their respective morphological parameters for preclinical skeletal studies.
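The opening and closing operations above can be sketched with shift-based binary morphology. This illustrative numpy implementation uses a 6-connected (cross-shaped) structuring element rather than the spherical elements with radii \(K_o = 3\) and \(R_c = 1\) used in the paper, and treats voxels outside the volume as bone during erosion; it demonstrates the operator definitions, not the published code.

```python
import numpy as np

def _dilate(mask):
    """Binary dilation (A ⊕ B) with a 6-connected structuring element."""
    padded = np.pad(mask, 1, constant_values=False)
    out = padded.copy()
    for axis in range(mask.ndim):
        for shift in (-1, 1):
            out |= np.roll(padded, shift, axis=axis)
    core = tuple(slice(1, -1) for _ in mask.shape)
    return out[core]

def _erode(mask):
    """Binary erosion (A ⊖ B); voxels outside the volume count as foreground."""
    padded = np.pad(mask, 1, constant_values=True)
    out = padded.copy()
    for axis in range(mask.ndim):
        for shift in (-1, 1):
            out &= np.roll(padded, shift, axis=axis)
    core = tuple(slice(1, -1) for _ in mask.shape)
    return out[core]

def opening(mask):
    # erosion then dilation: removes small protrusions and isolated voxels
    return _dilate(_erode(mask))

def closing(mask):
    # dilation then erosion: bridges small gaps between nearby trabeculae
    return _erode(_dilate(mask))
```

Production pipelines would typically call `scipy.ndimage.binary_opening`/`binary_closing` with a spherical structuring element instead of hand-rolled shifts.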

Architecture of DBAHNet

The proposed architecture, the Dual-Branch Attention-based Hybrid Network (DBAHNet), features a dual-branch hybrid design that incorporates both convolutional neural networks (CNNs) and transformers in the encoder and decoder pathways (see Fig. 4A). The patch embedding block projects the 3D scan into an embedding space with \(C = 96\) channels via successive convolutions. This process results in a reduced-dimensionality space, defined by the reduction embedding vector \(E = [4, 4, 4]\), creating a patch embedding of size \((C, \frac{H}{4}, \frac{W}{4}, \frac{D}{4})\), where \(H\), \(W\), and \(D\) represent the height, width, and depth of the input 3D scan, respectively. This embedding serves as the input to both the transformer and convolutional branches, each consisting of three hierarchical levels.

In the encoder pathway, each level comprises two sequential Swin transformer blocks in the transformer branch and a Channel Attention-Based Convolution Module (CACM) in the convolution branch. The transformer branch uses 3D-adapted Swin transformers to process feature maps at multiple scales, capturing global long-range dependencies within the volume. Each transformer block consists of two layers; the first employs regular volume partitioning, whereas the second uses shifted partitioning to increase the connectivity between layers. In the convolution branch, the CACM enhances cross-channel interaction by concatenating the outputs of global average pooling and maximum pooling, followed by two GeLU-activated 3D convolutions to create an attention map. This map modulates the initial feature map through elementwise multiplication, and a final 3D convolution further encodes the output for subsequent layers.

The outputs from the transformer and convolution branches at each level are fused via the Transformer-Convolution Feature Fusion Module (TCFFM). The TCFFM performs downsampling in the encoder by applying channelwise average pooling to \(x_{\text {Tr}}\) and \(x_{\text {C}}\) (the feature maps from the transformer and convolution branches), followed by a sigmoid function to generate an attention mask that filters the channels. The results are then concatenated and encoded through a 3D convolution layer. After encoding, the resulting feature maps are downscaled to \((8C, \frac{H}{32}, \frac{W}{32}, \frac{D}{32})\) and passed to the bottleneck. The bottleneck consists of four global 3D transformer blocks that perform global attention over all the downsampled feature maps, aggregating information to provide a comprehensive representation for the decoder.
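The shape bookkeeping above can be traced with a few lines: patch embedding divides each spatial dimension by 4, and each of the three encoder levels then halves the spatial dimensions while doubling the channels, yielding the \((8C, \frac{H}{32}, \frac{W}{32}, \frac{D}{32})\) bottleneck input. This is a sketch of the arithmetic only, not the model itself.

```python
def dbahnet_shapes(H, W, D, C=96, levels=3):
    """Trace feature-map shapes from the patch embedding to the bottleneck."""
    shapes = [(C, H // 4, W // 4, D // 4)]   # patch embedding, E = [4, 4, 4]
    c, h, w, d = shapes[0]
    for _ in range(levels):                  # each level: channels x2, dims /2
        c, h, w, d = 2 * c, h // 2, w // 2, d // 2
        shapes.append((c, h, w, d))
    return shapes

# For a 256^3 input crop:
# [(96, 64, 64, 64), (192, 32, 32, 32), (384, 16, 16, 16), (768, 8, 8, 8)]
```

The last entry, \(768 = 8 \times 96\) channels at \(256/32 = 8\) voxels per side, matches the bottleneck dimensions stated in the text.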

The decoder mirrors the encoder symmetrically. It uses the spatial attention-based convolution module (SACM) instead of the CACM to enhance relevant spatial features for focused reconstruction of the segmentation mask. The SACM applies max-pooling and average-pooling, concatenates the results, and uses a \(1 \times 1 \times 1\) convolution to create an attention map. This attention map modulates the input feature map, which is further processed by a final 3D convolution. The TCFFM module in the decoder performs upsampling, restoring the original volume size. Throughout the decoder, feature maps from all layers are filtered via attention gates and residual skip connections from the encoder. Finally, a transpose convolution reconstructs the segmentation masks. All internal components of DBAHNet are illustrated in Fig. 4B.

Fig. 4

(a) Global architecture of the dual-branch attention-based hybrid network (DBAHNet) for 3D µCT mouse bone tibia imaging segmentation. (b) Diagram of all internal modules of DBAHNet, including the channel-wise attention-based convolution module (CACM), the spatial-wise attention-based convolution module (SACM), the transformer-convolution feature fusion module (TCFFM), the bottleneck, the patch embedding block, the attention gate, and the transformer block.

Transformer block

We leveraged a 3D adaptation of Swin transformers32, which perform self-attention within a local volume of feature maps at each hierarchical level to capture enriched contextual representations of the data. Each transformer unit consists of two consecutive transformers. The first transformer employs regular volume partitioning, whereas the second transformer introduces shifted local volume partitioning to ensure connectivity with the preceding layer’s local volumes. For a given layer \(l\), the input \({\textbf{x}}^{l-1}\) first undergoes layer normalization (LN) and is then processed by a multihead self-attention (MHSA) mechanism. The output of the MHSA is added to the original input via a residual connection, resulting in the intermediate output \(\hat{{\textbf{x}}}^l\). Next, \(\hat{{\textbf{x}}}^l\) is normalized again and passed through a multilayer perceptron (MLP), with another residual connection to produce the output \({\textbf{x}}^l\). The second transformer, which uses shifted partitioning, applies a shifted multihead self-attention (SMHSA) mechanism. This shifted transformer increases the connectivity between layers. The normalized output \({\textbf{x}}^l\) from the previous step is processed by the SMHSA with a residual connection, resulting in the intermediate output \(\hat{{\textbf{x}}}^{l+1}\). Finally, \(\hat{{\textbf{x}}}^{l+1}\) undergoes another normalization and passes through an MLP, with a residual connection to yield the output \({\textbf{x}}^{l+1}\). The Swin transformer block is expressed by the system of equations in Eq. (4).

$$\begin{aligned} \begin{aligned} \hat{{\textbf{x}}}^l&= \text {MHSA}\left( \text {LN}\left( {\textbf{x}}^{l-1}\right) \right) + {\textbf{x}}^{l-1}, \\ {\textbf{x}}^l&= \text {MLP}\left( \text {LN}\left( \hat{{\textbf{x}}}^l\right) \right) + \hat{{\textbf{x}}}^l, \\ \hat{{\textbf{x}}}^{l+1}&= \text {SMHSA}\left( \text {LN}\left( {\textbf{x}}^l\right) \right) + {\textbf{x}}^l, \\ {\textbf{x}}^{l+1}&= \text {MLP}\left( \text {LN}\left( \hat{{\textbf{x}}}^{l+1}\right) \right) + \hat{{\textbf{x}}}^{l+1} \end{aligned} \end{aligned}$$

(4)

The self-attention mechanism is computed using Eq. (5).

$$\begin{aligned} \text {Attention}(Q,K,V) = \text {Softmax}\left( \frac{QK^T}{\sqrt{d_k}}\right) V \end{aligned}$$

(5)

Here, \(Q\), \(K\), and \(V\) represent queries, keys, and values, respectively, and \(d_k\) is the dimension of the key and query.
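Eq. (5) reduces to a few lines of numpy; the sketch below is for illustration (batched heads, masking, and the learned projections are omitted). The same computation underlies the SMHSA within shifted local volumes and, with decoder queries against encoder keys and values, the attention gates described later.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention of Eq. (5): Softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]                 # key/query dimension
    scores = Q @ K.T / np.sqrt(d_k)   # similarity of each query to each key
    return softmax(scores) @ V        # weighted average of the values
```

When a query aligns strongly with a single key, the softmax weights approach a one-hot vector and the output approaches the corresponding value row.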

Channel-wise attention-based convolution module (CACM)

In the encoder, we utilized a convolution unit based on channelwise attention, assigning distinct levels of importance to different channels, thereby enhancing feature representation. Let \(x \in {\mathbb {R}}^{C \times H \times W \times D}\) be the input feature map. We first apply both global average pooling and maximum pooling channelwise, yielding a \(\left( C, 1, 1, 1\right)\) vector, which is then concatenated. This concatenated vector undergoes a 3D convolution to an intermediate dimension, resulting in a \(\left( \frac{C}{2}, 1, 1, 1\right)\) size, followed by a GeLU activation function. This output is further processed through a second 3D convolution to restore the original channel dimension. An attention map is subsequently generated via a sigmoid activation function, which is then elementwise multiplied with the initial feature map, modulating it on the basis of channelwise attention. Finally, a third convolution is applied, downsampling the dimensions to \(\left( 2C, \frac{H}{2}, \frac{W}{2}, \frac{D}{2}\right)\), to be used in subsequent layers.
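On a pooled \((C, 1, 1, 1)\) vector, the two 1×1×1 convolutions reduce to plain matrix products, so the CACM attention map can be sketched as below. The weight matrices `W1` and `W2` stand in for hypothetical learned convolution kernels, and the final downsampling convolution to \((2C, \frac{H}{2}, \frac{W}{2}, \frac{D}{2})\) is omitted; this is not the authors' implementation.

```python
import numpy as np

def gelu(x):
    # tanh approximation of the GeLU activation
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def channel_attention(x, W1, W2):
    """CACM sketch: x is (C, H, W, D); W1 is (C/2, 2C), W2 is (C, C/2)."""
    avg = x.mean(axis=(1, 2, 3))           # global average pooling -> (C,)
    mx = x.max(axis=(1, 2, 3))             # global max pooling -> (C,)
    v = np.concatenate([avg, mx])          # concatenated descriptor -> (2C,)
    a = sigmoid(W2 @ gelu(W1 @ v))         # per-channel attention weights in (0, 1)
    return x * a[:, None, None, None]      # modulate each channel of x
```

Each channel is rescaled by a single learned weight, so the module reweights feature maps without mixing spatial locations.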

Spatial-wise attention-based convolution module (SACM)

In the decoder, we employed a convolution module that ensures spatial attention; this module focuses selectively on the salient features and regions during the reconstruction of the segmentation mask, aiding in the preservation of detailed structures and enhancing accuracy. Let \(x\) be the input feature map such that \(x \in {\mathbb {R}}^{C \times H \times W \times D}\). Initially, we apply both max-pooling and average-pooling to extract two robust feature descriptors. These descriptors are concatenated along the channel axis before undergoing a \(1 \times 1 \times 1\) convolution to yield a feature map of dimensions \((1, H, W, D)\). Next, a sigmoid activation function derives the attention map, which is then elementwise multiplied with the original input to obtain a feature map of dimensions \((C, H, W, D)\). Considering the necessity of upsampling the feature maps during the decoding phase, a transpose 3D convolution operation with a stride of 2 is utilized to upsample the features, resulting in the final feature maps of dimensions \(\left( \frac{C}{2}, 2H, 2W, 2D\right)\).
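Because the \(1 \times 1 \times 1\) convolution acts on only two pooled channels, it reduces to a weighted sum, giving the compact SACM sketch below. The parameters `w` and `b` are hypothetical learned values, and the transpose-convolution upsampling to \((\frac{C}{2}, 2H, 2W, 2D)\) is omitted; this is an illustration, not the published module.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def spatial_attention(x, w=(0.5, 0.5), b=0.0):
    """SACM sketch: x is (C, H, W, D); one attention value per spatial location."""
    mx = x.max(axis=0)                        # channel-wise max  -> (H, W, D)
    avg = x.mean(axis=0)                      # channel-wise mean -> (H, W, D)
    a = sigmoid(w[0] * mx + w[1] * avg + b)   # spatial attention map in (0, 1)
    return x * a[None]                        # same map applied to every channel
```

In contrast to the CACM, every channel shares one attention value per voxel, emphasizing salient spatial regions during mask reconstruction.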

Transformer-convolution feature fusion module (TCFFM)

In the TCFFM block, the feature maps obtained from both the transformer and convolution pathways, denoted as \(x_{\text {Tr}}\) and \(x_{\text {C}}\), each belonging to the space \({\mathbb {R}}^{C \times H \times W \times D}\), are fused at each hierarchical level. Here, \(H\), \(W\), and \(D\) represent the dimensions of the feature maps, and \(C\) is the number of channels. Initially, channel-wise average pooling is applied to \(x_{\text {Tr}}\) and \(x_{\text {C}}\) to extract a representative value for each channel of the feature maps. These values are transformed into weights using a sigmoid function, generating an attention mask that enhances significant channels and suppresses less relevant channels. The results are subsequently concatenated and passed through a downsampling convolution layer, followed by a local-volume transformer block, to perform the fusion and leverage the combined strengths of both pathways in the subsequent layers.
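The channel-gating and concatenation at the heart of the TCFFM can be sketched as follows; the trailing downsampling convolution and local-volume transformer block are omitted, so this shows only the fusion step and is our reading of the description, not the authors' code.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tcffm_fuse(x_tr, x_c):
    """Fuse transformer (x_Tr) and convolution (x_C) maps, each (C, H, W, D)."""
    m_tr = sigmoid(x_tr.mean(axis=(1, 2, 3)))    # per-channel gate for x_Tr
    m_c = sigmoid(x_c.mean(axis=(1, 2, 3)))      # per-channel gate for x_C
    gated_tr = x_tr * m_tr[:, None, None, None]  # enhance/suppress channels
    gated_c = x_c * m_c[:, None, None, None]
    return np.concatenate([gated_tr, gated_c], axis=0)   # -> (2C, H, W, D)
```

The concatenated \((2C, H, W, D)\) tensor would then be encoded and downsampled (encoder) or upsampled (decoder) as described in the text.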

Bottleneck

In the bottleneck, we reduced the dimensionality of the resulting feature maps from the encoder and employed a series of four global 3D transformer blocks, similar to those used in the Vision Transformer (ViT)31. These blocks perform global attention over all the downsampled feature maps. They excel at aggregating information from the entire feature map, enabling an understanding of the global context and providing a comprehensive representation to the decoder.

Attention gate

Instead of using regular concatenation in the skip connections such as those in U-Net17, we employed attention gates (AGs)30 to enhance the model’s ability to focus on target structures of varying shapes and sizes. Attention gates automatically learn to suppress irrelevant regions in an input image while highlighting salient features relevant to a specific task.

Specifically, the output of the \(l^e\)-th TCFFM of the encoder, \(X_{l}^e\), is transformed via a linear projection into a key matrix \(K_l^e\) and a value matrix \(V_l^e\). This transformation encodes the spatial and contextual information necessary for the attention mechanism. The output feature maps after the \(l^d\)-th upsampling layer of the TCFFM in the decoder, denoted \(X_{l}^d\), serve as the query \(Q_l^d\). We apply one layer of the transformer block to \(Q_l^d\), \(K_l^e\), and \(V_l^e\) in the decoder, computing self-attention as previously described for the transformer block.



