mlx-swift-examples
Rep penalty vlms
Add MaskedRepetitionContext for VLM Image Token Exclusion
Overview
This PR introduces MaskedRepetitionContext, a new LogitProcessor that extends the existing repetition penalty functionality to support excluding specific tokens (such as image tokens in Vision-Language Models) from repetition penalties.
Problem
In Vision-Language Models (VLMs), image patch tokens often need to repeat naturally to represent visual content. The existing RepetitionContext applies penalties to all repeated tokens, which can degrade VLM performance by incorrectly penalizing legitimate image token repetitions.
Solution
MaskedRepetitionContext accepts a boolean mask array that identifies which tokens should be excluded from the repetition penalty calculation, allowing:
- Text tokens: Receive normal repetition penalty to maintain quality
- Image tokens: Repeat freely without penalty to preserve visual understanding
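The masking rule above can be sketched in plain Swift on a `[Float]` logits vector. This is a minimal illustration under stated assumptions, not the PR's implementation: `applyMaskedPenalty` is a hypothetical helper, a token is penalized if it has at least one unmasked occurrence in the context (an assumption), and the penalty direction (divide non-negative logits, multiply negative ones) mirrors what RepetitionContext does with MLX arrays.

```swift
/// Hypothetical helper: applies a repetition penalty only to tokens whose
/// mask entry is false. `context` and `mask` are parallel arrays.
func applyMaskedPenalty(
    logits: [Float], context: [Int], mask: [Bool], penalty: Float
) -> [Float] {
    precondition(context.count == mask.count, "mask must align with context")
    // Collect the distinct unmasked token IDs; masked (e.g. image) tokens are skipped.
    var penalized = Set<Int>()
    for (token, isMasked) in zip(context, mask) where !isMasked {
        penalized.insert(token)
    }
    var result = logits
    for token in penalized where token >= 0 && token < result.count {
        // Make the token less likely: shrink non-negative logits, grow negative ones.
        result[token] = result[token] < 0 ? result[token] * penalty : result[token] / penalty
    }
    return result
}

// Vocabulary of 5 tokens; token 3 plays the role of an image token.
let logits: [Float] = [1.0, 2.0, -1.0, 2.0, 0.5]
let adjusted = applyMaskedPenalty(
    logits: logits, context: [1, 2, 3, 3], mask: [false, false, true, true], penalty: 2.0)
print(adjusted)  // [1.0, 1.0, -2.0, 2.0, 0.5]
```

Token 3 repeats in the context but keeps its logit because both occurrences are masked; tokens 0 and 4 are untouched because they never appear in the context.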
Basic Usage Example
```swift
import MLX
import MLXLMCommon

// Create a MaskedRepetitionContext processor
var processor = MaskedRepetitionContext(
    repetitionPenalty: 1.1,    // Penalty factor: positive logits of repeated tokens are divided by 1.1, negative ones multiplied
    repetitionContextSize: 20  // Consider the last 20 tokens for repetition
)

// Example prompt tokens where token 32000 is an image token
let promptTokens = [1, 15, 32000, 32000, 42, 123]          // 32000 = image token
let imageMask = [false, false, true, true, false, false]  // true = exclude from penalty

// Initialize the processor with the prompt and its mask
let promptArray = MLXArray(promptTokens)
processor.prompt(promptArray, mask: imageMask)

// During generation, only tokens [1, 15, 42, 123] are penalized;
// image tokens [32000, 32000] can repeat without penalty
```
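To check the comment at the end of this example, here is a minimal sketch of how a masked context could be seeded from a prompt: keep the last `repetitionContextSize` (token, mask) pairs and penalize only the unmasked IDs. `penalizedPromptTokens` is a hypothetical helper; trimming to the trailing window follows RepetitionContext's `prompt(_:)`, while the mask handling is an assumption about MaskedRepetitionContext.

```swift
/// Hypothetical helper: returns the distinct token IDs from the last
/// `contextSize` prompt positions whose mask entry is false.
func penalizedPromptTokens(_ tokens: [Int], mask: [Bool], contextSize: Int) -> [Int] {
    precondition(contextSize > 0 && tokens.count == mask.count)
    // Keep only the trailing window, reading the mask at the same positions.
    let start = max(0, tokens.count - contextSize)
    var seen = Set<Int>()
    var result: [Int] = []
    for i in start..<tokens.count where !mask[i] && !seen.contains(tokens[i]) {
        seen.insert(tokens[i])
        result.append(tokens[i])  // preserve first-occurrence order
    }
    return result
}

let promptTokens = [1, 15, 32000, 32000, 42, 123]
let imageMask = [false, false, true, true, false, false]
print(penalizedPromptTokens(promptTokens, mask: imageMask, contextSize: 20))  // [1, 15, 42, 123]
print(penalizedPromptTokens(promptTokens, mask: imageMask, contextSize: 3))   // [42, 123]
```

With a window of 20 the whole prompt fits, reproducing `[1, 15, 42, 123]`; with a window of 3 the older text tokens fall out and the masked 32000 at position 3 is still skipped.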
Integration with TokenIterator
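```swift
// Use with TokenIterator for generation.
// `lmInput`, `model`, and `imageTokenId` are assumed to be defined elsewhere.
let sampler = CategoricalSampler(temperature: 0.7)
let iterator = try TokenIterator(
    input: lmInput,
    model: model,
    processor: processor,  // Your MaskedRepetitionContext
    sampler: sampler,
    maxTokens: 100
)

// TokenIterator is a synchronous Sequence that yields token IDs as Int.
// Note: the iterator stores its own copy of `processor` and calls
// prompt(_:) and didSample(token:) on that copy itself. If
// MaskedRepetitionContext is a struct (like RepetitionContext), the call
// below updates only this local copy, not the iterator's penalty state.
for token in iterator {
    let isImageToken = (token == imageTokenId)
    // Update processor with mask information for new tokens
    processor.didSample(token: MLXArray(token), isMasked: isImageToken)
}
```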
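The per-token update in the loop can be sketched as a fixed-size ring buffer that stores each token together with its mask flag, so every `didSample` call is O(1). `MaskedRing` is a hypothetical type modeled on RepetitionContext's `tokens`/`index` circular buffer; it is not the PR's code, and it takes an `Int` token rather than an `MLXArray` to stay dependency-free.

```swift
/// Hypothetical ring buffer of (token, isMasked) pairs, modeled on
/// RepetitionContext's tokens/index circular buffer.
struct MaskedRing {
    let capacity: Int
    private(set) var tokens: [Int] = []
    private(set) var masked: [Bool] = []
    private var index = 0  // next slot to overwrite once the buffer is full

    init(capacity: Int) {
        precondition(capacity > 0, "capacity must be positive")
        self.capacity = capacity
    }

    /// O(1): append until full, then overwrite the oldest entry.
    mutating func didSample(token: Int, isMasked: Bool) {
        if tokens.count < capacity {
            tokens.append(token)
            masked.append(isMasked)
        } else {
            tokens[index] = token
            masked[index] = isMasked
            index = (index + 1) % capacity
        }
    }

    /// Distinct unmasked tokens currently in the window.
    var penalizedTokens: Set<Int> {
        var result = Set<Int>()
        for (token, isMasked) in zip(tokens, masked) where !isMasked {
            result.insert(token)
        }
        return result
    }
}

let imageTokenId = 32000
var ring = MaskedRing(capacity: 3)
for token in [7, 32000, 32000, 9, 11] {
    ring.didSample(token: token, isMasked: token == imageTokenId)
}
print(ring.penalizedTokens.sorted())  // [9, 11]
```

Keeping the mask in a parallel array at the same slot as the token means an overwrite evicts both together, so the mask can never drift out of alignment with the window.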
Files Changed
- Evaluate.swift: Added MaskedRepetitionContext implementation
- Tests/MLXLMTests/RepetitionPenaltyTests.swift: Comprehensive test suite
- mlx-swift-examples.xcodeproj/project.pbxproj: Added test file to build system
Key Features
- ✅ Backward Compatible: Implements the same LogitProcessor interface as RepetitionContext
- ✅ Flexible Masking: Supports any token type that should be excluded from the penalty
- ✅ Efficient Implementation: Uses a circular buffer with O(1) per-token updates
- ✅ VLM Optimized: Designed specifically for Vision-Language Model requirements
- ✅ Comprehensive Testing: Full test coverage, including edge cases
Testing
Running the Tests
To run the comprehensive test suite for repetition penalty functionality:
```shell
# Run all tests
xcodebuild test -scheme mlx-libraries-Package

# Run only the repetition penalty tests (-only-testing takes Target/Class)
xcodebuild test -scheme mlx-libraries-Package -only-testing:MLXLMTests/RepetitionPenaltyTests
```
What We're Testing
The test suite (RepetitionPenaltyTests.swift) validates:
- testBasicRepetitionContext: Verifies existing RepetitionContext functionality remains intact
- testMaskedRepetitionContextBasic: Tests basic masking behavior: masked tokens are excluded from the penalty
- testMaskedRepetitionContextAllMasked: Edge case where all tokens are masked (no penalties applied)
- testMaskedRepetitionContextDuringGeneration: Complex scenario simulating actual generation with mixed masked/unmasked tokens
- testMaskedRepetitionContextCircularBuffer: Validates circular buffer behavior when the context window is exceeded
- testMaskedRepetitionContextFallbackBehavior: Tests backward compatibility when no mask is provided
- testMaskedRepetitionContextPreconditions: Validates error handling and input validation
- testComparisonBetweenProcessors: Direct comparison between RepetitionContext and MaskedRepetitionContext behavior
Test Coverage Highlights
- ✅ Penalty Application Logic: Verifies correct penalty calculation (division for positive logits, multiplication for negative)
- ✅ Mask Handling: Ensures only unmasked tokens receive penalties
- ✅ Memory Management: Tests circular buffer behavior and context window management
- ✅ Edge Cases: Handles empty contexts, all-masked scenarios, and boundary conditions
- ✅ Integration: Validates compatibility with existing MLX generation pipeline
- ✅ Performance: Confirms O(1) token operations and efficient mask processing
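The penalty rule these tests check can be written out as follows, where C is the set of unmasked token IDs in the current context window, z_t is the logit for token t, and p > 1 is the repetition penalty. The notation is ours; the cases restate the division/multiplication rule above, and a masked token never enters C.

```latex
\tilde{z}_t =
\begin{cases}
  z_t / p,     & t \in C,\ z_t \ge 0 \\
  z_t \cdot p, & t \in C,\ z_t < 0 \\
  z_t,         & t \notin C
\end{cases}
\qquad
\text{e.g. } p = 1.1:\quad 2.0 \mapsto 2.0 / 1.1 \approx 1.818,\qquad -2.0 \mapsto -2.2
```

Both penalized branches lower the logit, so a repeated unmasked token always becomes less likely regardless of the sign of its logit.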
Benefits for VLMs
- Improved Generation Quality: VLMs can now apply repetition penalties selectively
- Better Image Understanding: Image tokens repeat naturally without artificial constraints
- Maintained Text Quality: Text tokens still receive appropriate repetition penalties
- Easy Integration: Drop-in replacement for existing repetition penalty usage
Breaking Changes
None. This is a purely additive feature that maintains full backward compatibility with existing RepetitionContext usage.