Add FocalLoss and class weighting for discrete landcover classification

Open Copilot opened this issue 5 months ago • 1 comments

Addresses severe class imbalance (e.g., roads <1% vs forests >40%) in discrete landcover training data through focal loss and automated class weighting.

Core Changes

New Loss Functions (geoai/utils.py)

  • FocalLoss: PyTorch module implementing focal loss (Lin et al. 2017) with configurable alpha/gamma parameters and class weight support
  • get_loss_function(): Factory for creating loss functions with unified configuration
  • Flexible ignore_index: Accepts int or False to handle uncertain/unlabeled pixels
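
As a rough illustration of why the focal term helps with imbalance, the sketch below evaluates the per-pixel focal loss from Lin et al. (2017), FL = alpha * (1 - p_t)^gamma * (-log p_t), for a confidently correct pixel and a poorly classified one. The helper names `focal_term` and `cross_entropy` are hypothetical and use only the standard library; this is not the geoai implementation.

```python
import math


def cross_entropy(pt: float) -> float:
    """Plain cross entropy for predicted probability `pt` of the true class."""
    return -math.log(pt)


def focal_term(pt: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Per-pixel focal loss: alpha * (1 - pt)^gamma * (-log pt)."""
    return alpha * (1.0 - pt) ** gamma * (-math.log(pt))


easy, hard = 0.95, 0.10  # illustrative probabilities, not measured values

# How much more the hard pixel contributes than the easy one under each loss
ratio_ce = cross_entropy(hard) / cross_entropy(easy)
ratio_focal = focal_term(hard) / focal_term(easy)

print(f"hard/easy ratio, cross entropy: {ratio_ce:.1f}")
print(f"hard/easy ratio, focal loss:    {ratio_focal:.1f}")
```

The hard pixel outweighs the easy one far more strongly under the focal loss than under plain cross entropy, so well-classified pixels from a dominant class contribute comparatively little to the total loss.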

Automated Class Weighting (geoai/utils.py)

  • compute_class_weights(): Scans label files to compute inverse frequency weights
  • Dual-mode operation: Inverse frequency (default) or pure custom weights via use_inverse_frequency flag
  • Custom multipliers: Fine-tune weights per class (e.g., {1: 3.0} to boost rare class)
  • Weight capping: max_weight parameter prevents training instability
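
The weighting steps above (inverse frequency, per-class multipliers, capping) can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the code of `compute_class_weights()`: the helper name `inverse_frequency_weights`, the normalisation `total / (num_classes * count)`, and the pixel counts are all made up for illustration.

```python
from typing import Dict, Optional


def inverse_frequency_weights(
    pixel_counts: Dict[int, int],
    custom_multipliers: Optional[Dict[int, float]] = None,
    max_weight: float = 50.0,
) -> Dict[int, float]:
    """Hypothetical helper: inverse-frequency weight per class, then multiply and cap."""
    total = sum(pixel_counts.values())
    num_classes = len(pixel_counts)
    multipliers = custom_multipliers or {}
    weights = {}
    for cls, count in pixel_counts.items():
        # Assumed normalisation; a class with no pixels gets the cap instead of a division by zero
        base = total / (num_classes * count) if count > 0 else max_weight
        # Apply the optional per-class multiplier, then cap to avoid extreme weights
        weights[cls] = min(base * multipliers.get(cls, 1.0), max_weight)
    return weights


# Made-up pixel counts; class 1 stands in for a rare class such as roads
counts = {0: 450_000, 1: 8_000, 2: 542_000}
weights = inverse_frequency_weights(counts, custom_multipliers={1: 3.0}, max_weight=50.0)
print(weights)
```

In this example the rare class would receive a weight well above the cap once its multiplier is applied, so `max_weight` is what bounds it.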

Enhanced Tile Filtering (geoai/utils.py)

  • min_feature_ratio parameter in export_geotiff_tiles(): Filters tiles with insufficient labeled content
  • Tracking: Reports skipped empty and background-heavy tiles separately
  • Default: False preserves original behavior
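
One way to picture the `min_feature_ratio` check is as a per-tile ratio of labeled pixels to all pixels. The sketch below assumes that background pixels have value 0 and that the ratio is defined this way; the helper names `feature_ratio` and `keep_tile` are hypothetical and do not come from `export_geotiff_tiles()`.

```python
from typing import List, Union


def feature_ratio(tile: List[List[int]], background: int = 0) -> float:
    """Fraction of pixels in a label tile that are not background (assumed definition)."""
    total = sum(len(row) for row in tile)
    if total == 0:
        return 0.0
    labeled = sum(1 for row in tile for value in row if value != background)
    return labeled / total


def keep_tile(tile: List[List[int]], min_feature_ratio: Union[float, bool] = False) -> bool:
    """Keep every tile when filtering is disabled, else require enough labeled pixels."""
    if min_feature_ratio is False:
        return True  # mirrors the documented default: no ratio-based filtering
    return feature_ratio(tile) >= min_feature_ratio


# 4x4 label tiles: one with a single labeled pixel, one that is pure background
sparse_tile = [[0, 0, 0, 0], [0, 3, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
empty_tile = [[0] * 4 for _ in range(4)]

print(keep_tile(sparse_tile, min_feature_ratio=0.05))  # 1/16 of pixels are labeled
print(keep_tile(empty_tile, min_feature_ratio=0.05))
print(keep_tile(empty_tile))  # filtering disabled
```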

Usage

from geoai.utils import compute_class_weights, get_loss_function, export_geotiff_tiles

# Generate tiles, skip those with <5% features
export_geotiff_tiles(
    in_raster="image.tif",
    out_folder="tiles/",
    in_class_data="labels.tif",
    skip_empty_tiles=True,
    min_feature_ratio=0.05
)

# Compute weights with custom multipliers
weights = compute_class_weights(
    labels_dir="tiles/labels/",
    num_classes=7,
    custom_multipliers={1: 3.0},  # boost roads
    max_weight=50.0
)

# Create focal loss with weights
loss_fn = get_loss_function(
    "focal",
    num_classes=7,
    use_class_weights=True,
    class_weights=weights,
    focal_alpha=0.25,
    focal_gamma=2.0
)

Backward Compatibility

All changes are opt-in. Existing code continues to work unchanged:

  • min_feature_ratio=False (default) disables filtering
  • Loss functions can still be instantiated manually
  • No changes to existing function signatures except one optional parameter

Documentation

  • examples/README_FOCAL_LOSS.md: User guide with best practices
  • examples/test_focal_loss.py: Runnable examples
  • LANDCOVER_ENHANCEMENTS.md: Implementation details
  • Unit tests in tests/test_utils.py

Original prompt

This section details the original issue you should resolve

<issue_title>Geoai Improvement to Support Training Using Discrete Class Landcover Data</issue_title> <issue_description>

Description

Hello geoai community. I would like to propose some additional parameters that I have added to the geoai package. They address issues I had when training a landcover model on sparse training data, rather than on a continuous landcover raster (as in the provided example). My discrete landcover classification has severe class imbalance (e.g., roads <1% vs forests >40% of pixels) and needed flexible loss functions that can handle sparse, categorical geospatial data. These changes have greatly improved results for my use case.

🚀 Proposed Enhancements Added:

  1. New Loss Function
       • FocalLoss: Custom PyTorch implementation specifically for class imbalance
       • Purpose: Focuses training on "hard examples" while down-weighting easy predictions
       • Benefit: Prevents dominant classes (forests) from overwhelming rare classes (roads, wetlands)
       • Parameters: focal_alpha (class balancing) and focal_gamma (hard example focus)

  2. Class Weight System
       • Dual-mode operation:
           • Inverse Frequency Mode: Auto-computes weights based on pixel abundance
           • Pure Custom Mode: Manual weight control without frequency calculation
       • Custom multipliers: Fine-tune specific landcover classes if one is being overpredicted
       • Weight capping: Prevents extreme values that destabilize training
       • Purpose: Ensures rare landcover types get appropriate attention during training

  3. Enhanced Tile Filtering
       • min_feature_ratio: Filters out background-heavy tiles
       • Purpose: Improves training efficiency by focusing on tiles with meaningful habitat content
       • Benefit: Reduces class imbalance at the tile level, leading to better feature learning

  4. Flexible Ignore Index
       • Configurable pixel exclusion: Handle unlabeled/uncertain areas
       • Purpose: Skip pixels with ambiguous or missing habitat labels
       • Benefit: Prevents model from learning incorrect patterns from uncertain data
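
The effect of a configurable ignore index on the averaged loss can be sketched with plain lists. The helper `masked_mean` and the value 255 are hypothetical, chosen only for illustration. The explicit `is False` check matters because in Python `0 == False`, so comparing labels against `False` directly would silently drop class 0.

```python
from typing import List, Union


def masked_mean(
    losses: List[float],
    targets: List[int],
    ignore_index: Union[int, bool] = 255,
) -> float:
    """Average per-pixel losses, skipping pixels whose label equals `ignore_index`.

    Passing ignore_index=False disables ignoring and averages over every pixel.
    """
    if ignore_index is False:
        kept = list(losses)
    else:
        kept = [loss for loss, label in zip(losses, targets) if label != ignore_index]
    # With no valid pixels, return 0.0 rather than dividing by zero
    return sum(kept) / len(kept) if kept else 0.0


# Made-up per-pixel losses; the last pixel is unlabeled (255) and has a large, meaningless loss
losses = [1.0, 3.0, 100.0]
targets = [0, 1, 255]

print(masked_mean(losses, targets, ignore_index=255))    # unlabeled pixel excluded
print(masked_mean(losses, targets, ignore_index=False))  # every pixel included
```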

Source code

1. New Loss Function

import torch
import torch.nn.functional as F


class FocalLoss(torch.nn.Module):
    """
    Focal Loss for addressing class imbalance in segmentation.

    Reference: Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017).
    Focal loss for dense object detection. ICCV.
    """

    def __init__(self, alpha=1.0, gamma=2.0, ignore_index=-100, reduction='mean', weight=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.weight = weight  # Class weights tensor of shape [C], or None

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Tensor of shape [N, C, H, W] where C is number of classes
            targets: Tensor of shape [N, H, W] with class indices
        """
        # ignore_index=False disables ignoring: use -100, which never matches a class index
        ignore_idx = -100 if self.ignore_index is False else self.ignore_index

        # Log probabilities over the class dimension
        log_probs = F.log_softmax(inputs, dim=1)

        # Unweighted per-pixel cross entropy, so that exp(-ce) is the true probability
        # of the target class (class weights would otherwise distort pt)
        ce_unweighted = F.nll_loss(log_probs, targets, ignore_index=ignore_idx, reduction='none')
        pt = torch.exp(-ce_unweighted)

        # Class-weighted cross entropy; ignored pixels contribute zero
        ce_loss = F.nll_loss(log_probs, targets, ignore_index=ignore_idx,
                             weight=self.weight, reduction='none')

        # Focal loss: FL = alpha * (1 - pt)^gamma * (-log(pt)), scaled by the class weight
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        # Reduce over non-ignored pixels only
        valid_mask = targets != ignore_idx
        if self.reduction == 'mean':
            if valid_mask.any():
                return focal_loss[valid_mask].mean()
            # No valid pixels: return a zero that stays attached to the graph
            return inputs.sum() * 0.0
        elif self.reduction == 'sum':
            return focal_loss[valid_mask].sum()
        else:
            return focal_loss


def get_loss_function(loss_name: str, 
                     ignore_index: Union[int, bool] = -100, 
                     num_classes: int = 2,
                     use_class_weights: bool = False,
                     class_weights: Optional[torch.Tensor] = None,
                     focal_alpha: float = 1.0,
                     focal_gamma: float = 2.0,
                     device: torch.device = None) -> torch.nn.Module:
    """Get loss function based on nam...

</details>

- Fixes opengeos/geoai#335

Copilot, Oct 30 '25 17:10

🚀 Deployed on https://6903a53d0d1aa900c749532f--opengeos.netlify.app

github-actions[bot], Oct 30 '25 17:10