
Add nucleus sampling (`top-p`) method

Open • stefanoamorelli opened this issue 7 months ago • 1 comment

Implement nucleus sampling (top-p sampling) as a new sampling method in the Gemma text generation toolkit. This addresses a gap at `gemma/gm/text/__init__.py:34` by providing the missing sampling strategy.

Background

Nucleus sampling was introduced in "The Curious Case of Neural Text Degeneration" (Holtzman et al., 2020). It is a dynamic sampling technique that has become a popular choice for high-quality text generation with modern LLMs.

Current State

The Gemma library currently supports:

  • ✅ Greedy sampling (Greedy)
  • ✅ Random sampling (RandomSampling)
  • ✅ Top-k sampling (TopkSampling)
  • ❌ Nucleus sampling (missing)

Problem

Each existing method has limitations:

| Method | Issue |
|--------|-------|
| Greedy | Repetitive, deterministic output |
| Random | May sample very unlikely tokens, leading to incoherent text |
| Top-k | Fixed candidate set size doesn't adapt to context uncertainty |

Proposed solution

Add NucleusSampling class that:

  1. Dynamically selects candidates based on cumulative probability mass.
  2. Adapts to context: it keeps fewer candidate tokens when the model is confident and more when it is uncertain.
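To illustrate point 2, here is a minimal standard-library sketch that counts how many tokens land in the nucleus for a peaked and for a flat distribution. The helper name `nucleus_size` is hypothetical and not part of Gemma.

```python
def nucleus_size(probs: list[float], p: float) -> int:
    """Return how many tokens the top-p nucleus keeps (hypothetical helper).

    The nucleus is the smallest set of highest-probability tokens whose
    cumulative probability is at least p.
    """
    cumulative = 0.0
    for count, prob in enumerate(sorted(probs, reverse=True), start=1):
        cumulative += prob
        if cumulative >= p:
            return count
    return len(probs)  # Guard against floating-point shortfall near p = 1.0


confident = [0.95, 0.03, 0.01, 0.01]  # The model strongly prefers one token
uncertain = [0.25, 0.25, 0.25, 0.25]  # The model has no clear preference

print(nucleus_size(confident, p=0.9))  # 1
print(nucleus_size(uncertain, p=0.9))  # 4
```

By contrast, top-k with `k=2` would keep exactly two candidates in both cases, whatever the shape of the distribution.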

Technical details

Algorithm

```mermaid
flowchart TD
    A[Convert logits → probabilities with temperature scaling] --> B[Sort tokens by probability, descending]
    B --> C[Find nucleus: smallest set of top tokens whose cumulative prob ≥ p]
    C --> D[Filter out tokens outside the nucleus]
    D --> E[Renormalize the remaining probabilities]
    E --> F[Sample from the filtered distribution]
```
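The flowchart steps can be sketched in plain Python for a single sequence. This is a minimal, standard-library-only illustration, and the function name `nucleus_sample` is hypothetical. A Gemma implementation would presumably operate on batched JAX arrays inside `get_next_tokens` instead.

```python
import math
import random


def nucleus_sample(
    logits: list[float], p: float, temperature: float, rng: random.Random
) -> int:
    """Sample one token id with nucleus (top-p) sampling (illustrative sketch)."""
    # A. Temperature-scaled softmax; subtracting the max keeps exp() stable.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # B. Sort token ids by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # C. Grow the nucleus until its cumulative probability reaches p.
    nucleus = []
    cumulative = 0.0
    for token_id in order:
        nucleus.append(token_id)
        cumulative += probs[token_id]
        if cumulative >= p:
            break

    # D + E. Keep only nucleus tokens and renormalize their probabilities.
    mass = sum(probs[i] for i in nucleus)
    weights = [probs[i] / mass for i in nucleus]

    # F. Draw one token from the filtered distribution.
    return rng.choices(nucleus, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -1.0, -3.0]
samples = [nucleus_sample(logits, p=0.9, temperature=1.0, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # [0, 1, 2]: tokens 3 and 4 fall outside the nucleus
```

Because the loop stops at the first token that pushes the cumulative mass to `p`, the highest-probability token is always kept, so the nucleus is never empty even for very small `p`.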

API design

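```python
import dataclasses

import jax

# SamplingMethod: the existing base class used by Greedy, RandomSampling and TopkSampling.
@dataclasses.dataclass(frozen=True, kw_only=True)
class NucleusSampling(SamplingMethod):
    temperature: float = 1.0  # Temperature scaling applied to the logits
    p: float = 0.9            # Nucleus threshold, in (0.0, 1.0]

    def get_next_tokens(self, logits: jax.Array, rng: jax.Array) -> jax.Array:
        """Returns one sampled token id per sequence, drawn from the top-p nucleus."""
        ...
```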

Usage

```python
import gemma.gm as gm

# Proposed API: NucleusSampling is not yet available in gm.text.

# Conservative (factual text)
sampler = gm.text.NucleusSampling(p=0.7, temperature=0.8)

# Balanced (general purpose)
sampler = gm.text.NucleusSampling(p=0.9, temperature=1.0)

# Creative (diverse output)
sampler = gm.text.NucleusSampling(p=0.95, temperature=1.2)
```
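Raising either `p` or `temperature` widens the nucleus, so the three presets keep progressively more candidates. The sketch below measures this on one set of made-up logits. The helper `nucleus_size_from_logits` is hypothetical, and the exact counts depend entirely on the logits chosen.

```python
import math


def nucleus_size_from_logits(logits: list[float], p: float, temperature: float) -> int:
    """Count tokens in the top-p nucleus after temperature scaling (hypothetical helper)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    # Unnormalized softmax weights, largest first.
    exps = sorted((math.exp(x - peak) for x in scaled), reverse=True)
    total = sum(exps)
    cumulative = 0.0
    for count, e in enumerate(exps, start=1):
        cumulative += e / total
        if cumulative >= p:
            return count
    return len(exps)


logits = [3.0, 2.0, 1.5, 1.0, 0.0, -1.0]  # Illustrative next-token logits
presets = {
    "conservative": {"p": 0.7, "temperature": 0.8},
    "balanced": {"p": 0.9, "temperature": 1.0},
    "creative": {"p": 0.95, "temperature": 1.2},
}
for name, kwargs in presets.items():
    print(name, nucleus_size_from_logits(logits, **kwargs))
# conservative 2
# balanced 4
# creative 5
```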

References

stefanoamorelli · May 25 '25 18:05

Hi @stefanoamorelli ,

Thank you so much for your contribution and for reporting the issue. We appreciate you submitting this PR. Your patience while it is reviewed is greatly valued.

Thanks.