Add FocalLoss and class weighting for discrete landcover classification
Addresses severe class imbalance (e.g., roads <1% vs forests >40%) in discrete landcover training data through focal loss and automated class weighting.
Core Changes
New Loss Functions (geoai/utils.py)
- `FocalLoss`: PyTorch module implementing focal loss (Lin et al. 2017) with configurable `alpha`/`gamma` parameters and class weight support
- `get_loss_function()`: Factory for creating loss functions with unified configuration
- Flexible `ignore_index`: Accepts `int` or `False` to handle uncertain/unlabeled pixels
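As a quick intuition for the focal term FL = -alpha * (1 - pt)^gamma * log(pt), the pure-Python sketch below compares it with plain cross-entropy for a single pixel (the `focal_term` helper is illustrative, not part of the package): the more confident the correct prediction, the more aggressively the loss is suppressed.

```python
import math

def focal_term(pt, alpha=0.25, gamma=2.0):
    """Focal loss for one pixel: -alpha * (1 - pt)^gamma * log(pt)."""
    return -alpha * (1 - pt) ** gamma * math.log(pt)

# Compare plain cross-entropy (-log pt) with the focal term: easy examples
# (pt near 1) are down-weighted far more aggressively than hard ones.
for pt in (0.9, 0.5, 0.1):
    ce = -math.log(pt)
    print(f"pt={pt}: CE={ce:.3f}  focal={focal_term(pt):.4f}")
```

With `gamma=0` and `alpha=1` the focal term reduces to ordinary cross-entropy, which is a useful sanity check when tuning the two parameters.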
Automated Class Weighting (geoai/utils.py)
- `compute_class_weights()`: Scans label files to compute inverse-frequency weights
- Dual-mode operation: inverse frequency (default) or pure custom weights via the `use_inverse_frequency` flag
- Custom multipliers: fine-tune weights per class (e.g., `{1: 3.0}` to boost a rare class)
- Weight capping: the `max_weight` parameter prevents training instability
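The core inverse-frequency idea with multipliers and capping can be sketched in a few lines of pure Python. The `inverse_frequency_weights` helper and the exact formula here are illustrative assumptions, not the package's implementation:

```python
from collections import Counter

def inverse_frequency_weights(label_pixels, num_classes,
                              custom_multipliers=None, max_weight=50.0):
    """Sketch: weight_c = total / (num_classes * count_c), then multiply and cap."""
    counts = Counter(label_pixels)
    total = len(label_pixels)
    weights = []
    for c in range(num_classes):
        n = counts.get(c, 0)
        # An unseen class simply gets the capped maximum weight.
        w = total / (num_classes * n) if n else max_weight
        if custom_multipliers and c in custom_multipliers:
            w *= custom_multipliers[c]
        weights.append(min(w, max_weight))
    return weights

# Class 1 covers only 2 of 20 pixels, so it gets a ~9x larger weight than class 0.
pixels = [0] * 18 + [1] * 2
print(inverse_frequency_weights(pixels, num_classes=2))
```

Capping matters because a class covering, say, 0.1% of pixels would otherwise receive a weight in the hundreds, which can destabilize gradient updates.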
Enhanced Tile Filtering (geoai/utils.py)
- `min_feature_ratio` parameter in `export_geotiff_tiles()`: filters tiles with insufficient labeled content
- Tracking: reports skipped empty and background-heavy tiles separately
- Default: `False` preserves the original behavior
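The filtering decision itself is simple: keep a tile only if the fraction of non-background label pixels meets the threshold. A minimal sketch, where `keep_tile` and `background_value` are illustrative stand-ins for the package's internal logic:

```python
def keep_tile(label_tile, background_value=0, min_feature_ratio=0.05):
    """Keep a tile only if the labeled (non-background) fraction meets the threshold."""
    flat = [p for row in label_tile for p in row]
    ratio = sum(p != background_value for p in flat) / len(flat)
    return ratio >= min_feature_ratio

# A 4x4 tile with one labeled pixel: 1/16 = 6.25% labeled, above the 5% default.
tile = [[0, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]
print(keep_tile(tile))
```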
Usage
```python
from geoai.utils import compute_class_weights, get_loss_function, export_geotiff_tiles

# Generate tiles, skip those with <5% features
export_geotiff_tiles(
    in_raster="image.tif",
    out_folder="tiles/",
    in_class_data="labels.tif",
    skip_empty_tiles=True,
    min_feature_ratio=0.05,
)

# Compute weights with custom multipliers
weights = compute_class_weights(
    labels_dir="tiles/labels/",
    num_classes=7,
    custom_multipliers={1: 3.0},  # boost roads
    max_weight=50.0,
)

# Create focal loss with weights
loss_fn = get_loss_function(
    "focal",
    num_classes=7,
    use_class_weights=True,
    class_weights=weights,
    focal_alpha=0.25,
    focal_gamma=2.0,
)
```
Backward Compatibility
All changes are opt-in. Existing code continues to work unchanged:
- `min_feature_ratio=False` (default) disables filtering
- Loss functions can still be instantiated manually
- No changes to existing function signatures except one optional parameter
Documentation
- `examples/README_FOCAL_LOSS.md`: user guide with best practices
- `examples/test_focal_loss.py`: runnable examples
- `LANDCOVER_ENHANCEMENTS.md`: implementation details
- Unit tests in `tests/test_utils.py`
Original prompt
This section details the original issue this PR resolves.
<issue_title>Geoai Improvement to Support Training Using Discrete Class Landcover Data</issue_title> <issue_description>
Description
Hello geoai community. I would like to propose some additional parameters I have added to the Geoai package to address the issues I had when training a landcover model whose training data is sparse rather than a continuous landcover raster (as in the provided example). My discrete landcover classification has severe class imbalance (e.g., roads <1% vs forests >40% of pixels) and needed flexible loss functions that can handle sparse, categorical geospatial data. These changes have improved my use case greatly.
🚀 Proposed Enhancements Added:
- New Loss Function: `FocalLoss` is a custom PyTorch implementation specifically for class imbalance. Purpose: focus training on "hard examples" while down-weighting easy predictions. Benefit: prevents dominant classes (forests) from overwhelming rare classes (roads, wetlands). Parameters: `focal_alpha` (class balancing) and `focal_gamma` (hard-example focus).
- Class Weight System: dual-mode operation. Inverse Frequency Mode auto-computes weights based on pixel abundance; Pure Custom Mode gives manual weight control without frequency calculation. Custom multipliers fine-tune specific landcover classes if one is being overpredicted, and weight capping prevents extreme values that destabilize training. Purpose: ensures rare landcover types get appropriate attention during training.
- Enhanced Tile Filtering: `min_feature_ratio` filters out background-heavy tiles. Purpose: improve training efficiency by focusing on tiles with meaningful habitat content. Benefit: reduces class imbalance at the tile level, leading to better feature learning.
- Flexible Ignore Index: configurable pixel exclusion to handle unlabeled/uncertain areas. Purpose: skip pixels with ambiguous or missing habitat labels. Benefit: prevents the model from learning incorrect patterns from uncertain data.
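The semantics of a flexible ignore index reduce to masking: average the per-pixel losses only over pixels whose label is not the ignore value, and skip masking entirely when it is disabled. A pure-Python sketch, where `masked_mean_loss` is a hypothetical helper used only to illustrate the behavior:

```python
def masked_mean_loss(per_pixel_losses, targets, ignore_index=-100):
    """Average loss over pixels whose label != ignore_index; False disables masking."""
    if ignore_index is False:
        valid = list(per_pixel_losses)
    else:
        valid = [l for l, t in zip(per_pixel_losses, targets) if t != ignore_index]
    # If every pixel is ignored, return zero rather than dividing by zero.
    return sum(valid) / len(valid) if valid else 0.0

losses = [0.2, 0.9, 0.4]
labels = [1, -100, 2]
print(masked_mean_loss(losses, labels))  # averages only the 0.2 and 0.4 pixels
```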
Source code
1. New Loss Function
```python
from typing import Optional, Union

import torch
import torch.nn.functional as F


class FocalLoss(torch.nn.Module):
    """
    Focal Loss for addressing class imbalance in segmentation.

    Reference: Lin, T. Y., Goyal, R., Girshick, R., He, K., & Dollár, P. (2017).
    Focal loss for dense object detection. ICCV.
    """

    def __init__(self, alpha=1.0, gamma=2.0, ignore_index=-100, reduction='mean', weight=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.weight = weight  # Class weights tensor

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Tensor of shape [N, C, H, W] where C is number of classes
            targets: Tensor of shape [N, H, W] with class indices
        """
        # Handle ignore_index parameter - if False, disable ignoring
        if self.ignore_index is False:
            ignore_idx = -100  # Use a value that won't match any target
        else:
            ignore_idx = self.ignore_index

        # Apply log_softmax to get log probabilities
        log_pt = F.log_softmax(inputs, dim=1)

        # Use NLL loss with ignore_index and class weights to handle ignored pixels properly
        ce_loss = F.nll_loss(log_pt, targets, ignore_index=ignore_idx,
                             weight=self.weight, reduction='none')

        # Get probabilities by exponentiating the negative cross entropy
        pt = torch.exp(-ce_loss)

        # Apply focal loss formula: FL = -alpha * (1-pt)^gamma * log(pt)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        # Handle reduction
        if self.reduction == 'mean':
            # Only average over non-ignored pixels
            if ignore_idx != -100:
                valid_mask = (targets != ignore_idx)
                if valid_mask.sum() > 0:
                    return focal_loss[valid_mask].mean()
                else:
                    return torch.tensor(0.0, device=inputs.device, requires_grad=True)
            else:
                return focal_loss.mean()
        elif self.reduction == 'sum':
            if ignore_idx != -100:
                valid_mask = (targets != ignore_idx)
                return focal_loss[valid_mask].sum()
            else:
                return focal_loss.sum()
        else:
            return focal_loss


def get_loss_function(loss_name: str,
                      ignore_index: Union[int, bool] = -100,
                      num_classes: int = 2,
                      use_class_weights: bool = False,
                      class_weights: Optional[torch.Tensor] = None,
                      focal_alpha: float = 1.0,
                      focal_gamma: float = 2.0,
                      device: torch.device = None) -> torch.nn.Module:
    """Get loss function based on nam...
```
</details>
- Fixes opengeos/geoai#335