Notice
Recent Posts
Recent Comments
Link
«   2025/04   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30
Archives
Today
Total
관리 메뉴

잡동사니 블로그

Dice Coefficient(Dice Score) -> Dice Loss Function 본문

공부용

Dice Coefficient(Dice Score) -> Dice Loss Function

코딩부대찌개 2024. 10. 9. 21:37

Dice Coefficient는 Image Segmentation에서 사용되며, 예측된 Mask와 실제 Mask 사이의 겹치는 정도를 평가하는 지표.

$$\text{ Dice Coefficient } = \frac{2 \times |A \cap B|}{|A| + |B|}$$

예를 들어, 실제 값은 100개의 1이 있고, 예측된 값은 50개의 1이 있으며, 이 중 30개가 겹친다면.

$$\text{ Dice Coefficient } = \frac{2 \times 30}{100 + 50} = 0.4$$

 

정확히 일치하면 1이기 때문에 1에 가까울 수록 많이 겹친다는 의미이다.

 

즉 Dice Loss FunctionDice Coefficient를 기반으로 하는 손실 함수이며 $1 - Dice Coefficient$를 통해 값이 점점 줄어들도록 학습 진행

 

import torch
import torch.nn as nn

class Dice_Coef_Loss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(Dice_Coef_Loss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        batch_size = y_true.shape[0]
        y_pred = torch.clamp(y_pred, 0, 1)
        y_true_f = y_true.view(batch_size, -1)
        y_pred_f = y_pred.view(batch_size, -1)

        intersection = torch.sum(y_true_f * y_pred_f, dim=-1)
        mask_sum = torch.sum(y_true_f, dim=-1) + torch.sum(y_pred_f, dim=-1)
        dice = (2. * intersection + self.smooth) / (mask_sum + self.smooth)
        
        return 1 - torch.mean(dice)
        
#Smoth는 zero division을 방지하기 위함