# Improve generated code for `TryFromPrimitive`
## Motivation

While inspecting the code generated by num_enum, I found that the code for `TryFromPrimitive` is suboptimal. In certain situations, the generated code leads to assembly that is both larger (binary size) and slower than simple handwritten code using `transmute`. This is unfortunate, since one of the main benefits of num_enum is that it allows low-level enum "serialization" without the use of `unsafe`.
## Description

The generated code uses a large `match` statement to match input values. This is a good general solution, but it can lead to inefficient code gen for certain enums.

In this issue, I will focus on enums where the discriminant values of all variants are statically known (= the derive macro knows the numerical value of each variant) and each variant has exactly one discriminant (= no alternatives). I will also ignore `default` and `catch_all`, but the idea I describe can easily be adapted to support them.
### Inefficient code gen

Consider the following enum, based on the DXGI_FORMAT enumeration:
```rust
#[derive(TryFromPrimitive)]
#[repr(u8)]
enum DxgiFormatEnum {
    UNKNOWN = 0,
    R32G32B32A32_TYPELESS = 1,
    // ... 113 more variants
    B4G4R4A4_UNORM = 115,
    P208 = 130,
    V208 = 131,
    V408 = 132,
    A4B4G4R4_UNORM = 191,
}
```
The generated implementation of `TryFromPrimitive::try_from_primitive` currently looks like this (simplified):
```rust
fn try_from_primitive(number: u8) -> Result<DxgiFormatEnum, TryFromPrimitiveError> {
    #![allow(non_upper_case_globals)]
    const UNKNOWN__num_enum_0__: u8 = 0;
    const R32G32B32A32_TYPELESS__num_enum_0__: u8 = 1;
    // ... and so on
    const B4G4R4A4_UNORM__num_enum_0__: u8 = 115;
    const P208__num_enum_0__: u8 = 130;
    const V208__num_enum_0__: u8 = 131;
    const V408__num_enum_0__: u8 = 132;
    const A4B4G4R4_UNORM__num_enum_0__: u8 = 191;
    #[deny(unreachable_patterns)]
    match number {
        UNKNOWN__num_enum_0__ => Ok(DxgiFormatEnum::UNKNOWN),
        R32G32B32A32_TYPELESS__num_enum_0__ => Ok(DxgiFormatEnum::R32G32B32A32_TYPELESS),
        // ... and so on
        B4G4R4A4_UNORM__num_enum_0__ => Ok(DxgiFormatEnum::B4G4R4A4_UNORM),
        P208__num_enum_0__ => Ok(DxgiFormatEnum::P208),
        V208__num_enum_0__ => Ok(DxgiFormatEnum::V208),
        V408__num_enum_0__ => Ok(DxgiFormatEnum::V408),
        A4B4G4R4_UNORM__num_enum_0__ => Ok(DxgiFormatEnum::A4B4G4R4_UNORM),
        #[allow(unreachable_patterns)]
        _ => Err(TryFromPrimitiveError(number)),
    }
}
```
This generates the following assembly on Rust 1.85.0 at O1, O2, O3, Os, and Oz (Compiler Explorer link):
```asm
try_from_primitive:
        mov     al, 1
        cmp     dil, -65                 ; compare the input primitive with 191
        ja      .LBB0_3                  ; if the input is > 191, go to the Err branch
        movzx   ecx, dil
        lea     rdx, [rip + .LJTI0_0]
        movsxd  rcx, dword ptr [rdx + 4*rcx]
        add     rcx, rdx
        jmp     rcx
.LBB0_2:                                 ; Ok branch
        xor     eax, eax
.LBB0_3:                                 ; Err branch
        mov     edx, edi
        ret
.LJTI0_0:
        .long   .LBB0_2-.LJTI0_0         ; 0 = UNKNOWN
        .long   .LBB0_2-.LJTI0_0         ; 1 = R32G32B32A32_TYPELESS
        ; ...
        .long   .LBB0_2-.LJTI0_0         ; 115 = B4G4R4A4_UNORM
        .long   .LBB0_3-.LJTI0_0         ; 116 Err
        ; ...
        .long   .LBB0_3-.LJTI0_0         ; 129 Err
        .long   .LBB0_2-.LJTI0_0         ; 130 = P208
        .long   .LBB0_2-.LJTI0_0         ; 131 = V208
        .long   .LBB0_2-.LJTI0_0         ; 132 = V408
        .long   .LBB0_3-.LJTI0_0         ; 133 Err
        ; ...
        .long   .LBB0_3-.LJTI0_0         ; 190 Err
        .long   .LBB0_2-.LJTI0_0         ; 191 = A4B4G4R4_UNORM
```
I annotated the assembly a little, but it basically works in two steps:

- Check whether the input primitive is > 191. If it is, return `Err`.
- Otherwise, look up whether to jump to the `Ok` or `Err` branch in a jump table.

I abbreviated the jump table here, but it has 192 entries (each 4 bytes). So to retrieve a single bit of information, we need to read a jump table entry and then `jmp` to the retrieved location.

This is quite inefficient. Both the memory read and the `jmp` are comparatively slow, and the jump table adds 768 bytes of binary size.
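For intuition, the dispatch is morally equivalent to the following hand-written sketch (the real jump table stores code offsets rather than bools, and `TABLE`/`is_valid_via_table` are my names, but the lookup cost is the same):

```rust
// Morally, the generated code answers "is `number` valid?" by
// bounds-checking and then loading from a 192-entry table.
const TABLE: [bool; 192] = {
    let mut t = [false; 192];
    let mut i = 0;
    while i <= 115 {
        t[i] = true; // 0..=115 are valid
        i += 1;
    }
    t[130] = true;
    t[131] = true;
    t[132] = true;
    t[191] = true;
    t
};

fn is_valid_via_table(number: u8) -> bool {
    (number as usize) < TABLE.len() && TABLE[number as usize]
}
```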
### More efficient code

Compare this to the following implementation I've written by hand:
```rust
fn try_from_primitive(number: u8) -> Result<DxgiFormatEnum, TryFromPrimitiveError> {
    let is_valid = matches!(number, 0..=115 | 130..=132 | 191);
    if is_valid {
        // SAFETY: `number` is a valid discriminant value
        Ok(unsafe { std::mem::transmute(number) })
    } else {
        Err(TryFromPrimitiveError(number))
    }
}
```
```asm
try_from_primitive:
        cmp     dil, 116                 ; compare input to 116
        jae     .LBB0_2                  ; if `input >= 116`, go to .LBB0_2
        xor     eax, eax                 ; otherwise, return Ok
        mov     edx, edi
        ret
.LBB0_2:
        lea     ecx, [rdi + 126]         ; compute `input - 130` (mod 256)
        mov     al, 1
        cmp     cl, 61                   ; compare `input - 130` to 61 (191 - 130 = 61)
        ja      .LBB0_4                  ; if `input - 130 > 61`, return Err
        movabs  rax, 2305843009213693944 ; this value is used as a bitset
        shr     rax, cl
.LBB0_4:
        mov     edx, edi
        ret
```
This is the output for O3; the other optimization levels produce virtually identical assembly. I want to point out that what the compiler produced here is quite good. It handled `0..=115` in its own branch and used a fast bitset lookup for `130..=132 | 191`. Note that there is no jump table at all. The assembly is just a few bytes and can easily be inlined.
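To make the bitset lookup explicit, here is a hand-written Rust sketch of what LLVM did for the upper range (the function name and constant derivation are mine, not compiler output):

```rust
// Sketch of the bitset lookup LLVM emitted for `130..=132 | 191`.
// Bit i of INVALID is 1 when the value `130 + i` is NOT a valid
// discriminant: bits 0..=2 (130..=132) and bit 61 (191) are clear.
fn is_valid_130_to_191(number: u8) -> bool {
    const INVALID: u64 = (1u64 << 61) - (1 << 3); // == 2305843009213693944
    let offset = number.wrapping_sub(130);
    offset <= 61 && (INVALID >> offset) & 1 == 0
}
```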
### The underlying issue

The underlying issue is two-fold:

- LLVM has a hard time optimizing the large `match` statement generated by num_enum.
- LLVM is inconsistent about when it optimizes the `match` statement generated by num_enum.

I want to expand on the second point. LLVM actually produces very efficient assembly for the `match` statement generated by num_enum *if* the enum is sufficiently simple.
Example:
```rust
#[derive(IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
enum Simple3 {
    A = 0,
    B = 1,
    C = 2,
}
```
```asm
try_from_primitive:
        cmp     dil, 3
        setae   al
        mov     edx, edi
        ret
```

The entire validity check collapses to a single compare, equivalent to `is_valid = number < 3`.
However, even relatively simple enums can cause large jump tables to be generated.
Example:
```rust
#[derive(IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
enum SimpleJumpTable {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
    E = 4,
    F = 5,
    Large = 65,
}
```
```asm
try_from_primitive:
        mov     al, 1
        cmp     dil, 64
        ja      .LBB0_3
        movzx   ecx, dil
        lea     rdx, [rip + .LJTI0_0]
        movsxd  rcx, dword ptr [rdx + 4*rcx]
        add     rcx, rdx
        jmp     rcx
.LBB0_2:
        xor     eax, eax
.LBB0_3:
        mov     edx, edi
        ret
.LJTI0_0:
        .long   .LBB0_2-.LJTI0_0
        .long   .LBB0_2-.LJTI0_0
        ; 63 more entries
```
Even small changes in discriminant values can have huge effects on the generated assembly.

Here is a Compiler Explorer link where you can quickly see the generated assembly for different variants. Just input the discriminant values into the `test!` macro.
### Suggested fix

Since the compiler seems to be bad at optimizing the large `match` statement generated by num_enum, I suggest splitting `TryFromPrimitive::try_from_primitive` into two parts:

- The first part determines whether the given input is a valid discriminant value.
- The second part uses the result of the first part and either transmutes the input or returns an error (or the `default`/`catch_all` variant).
In code, it would have the following form:
```rust
fn try_from_primitive(number: Self::Primitive) -> Result<Self, TryFromPrimitiveError> {
    const Variant0: u8 = ...;
    const Variant1: u8 = ...;
    const Variant2: u8 = ...;
    // ...
    const VariantN: u8 = ...;

    let is_valid = matches!(number, Variant0 | Variant1 | Variant2 | ... | VariantN);
    if is_valid {
        // SAFETY: `number` is a valid discriminant value
        Ok(unsafe { std::mem::transmute(number) })
    } else {
        Err(TryFromPrimitiveError(number))
    }
}
```
The compiler generally seems to be better at optimizing this form, leading to assembly that is as good as or better than that of the current code gen. You can test this out with this Compiler Explorer link.

Note, however, that `matches!` is not a perfect solution: it also generates jump tables sometimes, just less often. To fix this, I would further suggest manually choosing a good strategy for determining `is_valid` instead of leaving it entirely up to the compiler. I suggest the following:

Let `V` be the set of all discriminant values. Then determine `is_valid: bool` as follows:
- If `V` is an interval, generate `is_valid = (V_min..=V_max).contains(&number)`.
- If `V_max - V_min < 64`, use a 64-bit bitset.
- Otherwise, sort `V` and output a `matches!` with ranges instead of simple literal values, e.g. `matches!(number, 1..=4 | 6..=8 | 100)` instead of `matches!(number, 1 | 2 | 3 | 4 | 6 | 7 | 8 | 100)`. LLVM tends to produce better assembly this way.
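To make the three cases concrete, here are hand-written sketches of what each strategy could expand to (the discriminant sets and function names are made up for illustration):

```rust
// 1. `V` is a contiguous interval, e.g. V = {10, ..., 20}:
fn is_valid_interval(number: u8) -> bool {
    (10..=20).contains(&number)
}

// 2. `V_max - V_min < 64`, e.g. V = {3, 5, 6, 9}: a 64-bit bitset
//    where bit (d - V_min) is set for every valid discriminant d.
fn is_valid_bitset(number: u8) -> bool {
    const MIN: u8 = 3;
    const VALID: u64 = (1 << 0) | (1 << 2) | (1 << 3) | (1 << 6);
    let offset = number.wrapping_sub(MIN);
    // the bounds check also rejects values below MIN, which wrap around
    offset < 64 && (VALID >> offset) & 1 != 0
}

// 3. Fallback: a sorted `matches!` with ranges, e.g. V = {1..=4, 6..=8, 100}:
fn is_valid_ranges(number: u8) -> bool {
    matches!(number, 1..=4 | 6..=8 | 100)
}
```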
This approach should always produce assembly that is as good as or better than simply using `matches!` with all variants.

Note: If you choose not to implement the bitset method and want to rely on LLVM instead, I would advise against using `matches!` with ranges for small enums. LLVM tends to optimize those with a bitset, unless ranges are used in the `matches!`. (Again, LLVM is really inconsistent in how it optimizes this.)