amaranth
[WIP] Implement RFC 41: `lib.fixed`
Overview
- This is a staging ground for experimenting with the ergonomics of `lib.fixed` while the associated RFC is being worked on. It started as a fork of @zyp's early implementation of the RFC here; however, a few things have changed since then.
  - Most importantly, this PR adds some tests, which make the consequences of different design decisions more obvious.
  - Also, all operators required for real-world use are now implemented.
- As of now, this PR adheres to the latest version of the RFC, with some minor changes (these will be discussed further in the RFC issue):
  - New methods on `fixed.Value`:
    - `.saturate()` and `.clamp()`: these are commonly needed in my DSP codebase, but it may make sense to punt these new methods to a future RFC.
    - `.truncate()` is added as an alias for `.reshape()`, the only difference being that it verifies that a reduction of `f_bits` was requested.
  - The `numerator()` method is dropped, as I found a way to combine it with `as_value()` reliably.
- I have already integrated this implementation of `lib.fixed` in this Tiliqua PR and tested it underneath my library of DSP cores, and will continue to use the learnings there to guide the RFC.
- It should be obvious that this PR needs a cleanup pass, improved diagnostics, and a lot of documentation work after the RFC is complete. However, it is already usable in real projects, as all essential operations are implemented.
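To make the difference between `.reshape()` and `.truncate()` concrete, here is a minimal pure-Python model of a fixed-point value as a raw integer plus a count of fractional bits. It is not the PR's implementation: the class name `FixedModel` is hypothetical, and using an arithmetic right shift (rounding toward negative infinity) when fractional bits are dropped is an assumption made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FixedModel:
    """Hypothetical model of a fixed-point value: raw / 2**f_bits."""
    raw: int
    f_bits: int

    def as_float(self) -> float:
        return self.raw / 2 ** self.f_bits

    def reshape(self, f_bits: int) -> "FixedModel":
        # Adding fractional bits is exact (left shift); removing them is an
        # arithmetic right shift, i.e. floor (assumed rounding mode).
        if f_bits >= self.f_bits:
            return FixedModel(self.raw << (f_bits - self.f_bits), f_bits)
        return FixedModel(self.raw >> (self.f_bits - f_bits), f_bits)

    def truncate(self, f_bits: int) -> "FixedModel":
        # Same as reshape(), but only accepts a reduction of f_bits.
        if f_bits >= self.f_bits:
            raise ValueError("truncate() must reduce f_bits")
        return self.reshape(f_bits)


x = FixedModel(raw=-11, f_bits=3)  # -1.375
print(x.reshape(5).as_float())     # -1.375 (exact)
print(x.truncate(1).as_float())    # -1.5 (floored)
```

The extra check is the whole point of the alias: a call that accidentally widens instead of narrowing raises instead of silently succeeding.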
Simple example
Consider the implementation of a simple low-pass filter, where we wish to compute the difference equation `y[n] = y[n-1] * beta + x[n] * (1 - beta)` using a fixed-point representation:
```python
from amaranth import *
from amaranth.lib import fixed, wiring
from amaranth.lib.wiring import In, Out

class OnePole(wiring.Component):
    def __init__(self, beta=0.9, sq=fixed.SQ(1, 15)):
        self.beta = beta
        self.sq = sq
        super().__init__({
            "x": In(sq),
            "y": Out(sq),
        })

    def elaborate(self, platform):
        m = Module()
        a = fixed.Const(self.beta, shape=self.sq)
        b = fixed.Const(1 - self.beta, shape=self.sq)
        # Quantization from wider to smaller fixed.Value occurs on the `.eq()`
        m.d.sync += self.y.eq(self.y * a + self.x * b)
        return m
```
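The quantization that happens on `.eq()` can be checked with a small integer model of the same filter. This is a sketch, not the hardware: it assumes that `fixed.Const` rounds the coefficients to the nearest representable `SQ(1, 15)` value and that the `.eq()` drops the extra fractional bits by flooring (an arithmetic right shift). `to_raw` and `one_pole_step` are hypothetical helpers.

```python
F_BITS = 15  # fractional bits of SQ(1, 15)


def to_raw(value: float, f_bits: int = F_BITS) -> int:
    # Assumed constant quantization: round to nearest
    return round(value * 2 ** f_bits)


def one_pole_step(y_raw: int, x_raw: int, a_raw: int, b_raw: int) -> int:
    # y*a + x*b carries 2*F_BITS fractional bits; assigning it back to
    # SQ(1, 15) drops F_BITS of them (floor via arithmetic right shift)
    return (y_raw * a_raw + x_raw * b_raw) >> F_BITS


beta = 0.9
a_raw, b_raw = to_raw(beta), to_raw(1 - beta)  # 29491, 3277
x_raw = to_raw(0.5)  # constant input of 0.5
y_raw = 0
for _ in range(200):
    y_raw = one_pole_step(y_raw, x_raw, a_raw, b_raw)
print(y_raw / 2 ** F_BITS)  # 0.499725341796875, just below 0.5
```

Under these assumptions the output settles 9 LSBs below the input because flooring always rounds the feedback term down; a round-to-nearest quantizer would reduce that bias.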
I noticed that `==` seems to work differently than it does for standard types. Is this intended? I ran into this while writing assertions:
```python
from amaranth import *
from amaranth.lib import fixed
from amaranth.lib.wiring import Component, In, Out

class C(Component):
    o: Out(1)

    def elaborate(self, platform):
        m = Module()
        self.s = Signal(2)
        self.f = Signal(fixed.SQ(2, 2))
        m.d.sync += self.o.eq(self.s)
        return m

from amaranth.sim import Period, Simulator

dut = C()

async def bench(ctx):
    for _ in range(1):
        print(ctx.get(dut.s) == 0)  # True
        print(ctx.get(dut.f) == 0)  # (== (s (slice (const 4'sd0) 0:4)) (cat (const 2'd0) (const 1'd0)))
        await ctx.tick()

sim = Simulator(dut)
sim.add_clock(Period())
sim.add_testbench(bench)
sim.run()
```
Output:

```
True
(== (s (slice (const 4'sd0) 0:4)) (cat (const 2'd0) (const 1'd0)))
```
@goekce Thanks for taking a look! I think this happens because the comparison is done outside the simulation context rather than inside it: `from_bits()` converts the raw value back up to a `fixed.SQ` constant, and `==` on that constant builds an Amaranth expression instead of returning a Python `bool`. I see similar behaviour with other types from `amaranth.lib`, for example `enum`:
```python
from amaranth import *
from amaranth.lib import fixed, enum
from amaranth.lib.wiring import Component, In, Out

class Funct4(enum.Enum, shape=unsigned(4)):
    ADD = 0
    SUB = 1
    MUL = 2

class C(Component):
    o: Out(1)

    def elaborate(self, platform):
        m = Module()
        self.s = Signal(2)
        self.e = Signal(Funct4)
        self.f = Signal(fixed.SQ(2, 2))
        m.d.sync += self.o.eq(self.s)
        return m

from amaranth.sim import Period, Simulator

dut = C()

async def bench(ctx):
    for _ in range(1):
        # raw result of from_bits()
        print(ctx.get(dut.s))  # 0
        print(ctx.get(dut.e))  # Funct4.ADD
        print(ctx.get(dut.f))  # fixed.SQ(2, 2) (const 4'sd0)
        # compare inside the simulation context
        print(ctx.get(dut.s == 0))  # 1
        print(ctx.get(dut.e == Funct4.ADD))  # 1
        print(ctx.get(dut.f == 0))  # 1
        print(ctx.get(dut.f == fixed.Const(0, shape=fixed.SQ(2, 2))))  # 1
        # compare outside the simulation context
        print(ctx.get(dut.e) == 0)  # False (even though Funct4.ADD is 0)
        print(ctx.get(dut.f) == 0)  # (== (s (slice (const 4'sd0) 0:4)) (cat (const 2'd0) (const 1'd0)))
        await ctx.tick()

sim = Simulator(dut)
sim.add_clock(Period())
sim.add_testbench(bench)
sim.run()
```
For examples of how to write assertions, take a look at the tests attached to this PR. That being said, this is still a work in progress and not quite ready for review yet!
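The `False` printed by `ctx.get(dut.e) == 0` matches standard Python behaviour: members of a plain `enum.Enum` never compare equal to integers; only their `.value` does. A minimal illustration with the standard-library `enum` module (this `Funct4` is a stand-in without an Amaranth shape):

```python
import enum


class Funct4(enum.Enum):
    ADD = 0
    SUB = 1
    MUL = 2


print(Funct4.ADD == 0)          # False: enum members are not ints
print(Funct4.ADD.value == 0)    # True: compare the underlying value
print(Funct4(0) is Funct4.ADD)  # True: look up a member by its value
```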
@goekce A topic related to your example's evaluation outside the simulation context is elaboration-time evaluation of constant expressions, which is something I'd like to leave out of this RFC, even if we could add it in the future. For example:
```
>>> from amaranth import *
>>> Const(5) + Const(3)
(+ (const 3'd5) (const 2'd3))
>>> Const(5) == Const(3)
(== (const 3'd5) (const 2'd3))
```
I was not aware that I could do the comparison at simulation time. Thanks Seb, this solves my problem.
As I understand it, elaboration-time evaluation of constant expressions would be extra work. 👍
You can already compare constant expressions at elaboration time with `Const.cast(a).value == Const.cast(b).value`.
For comparison with zero that makes sense, but for other values this kind of comparison is not so intuitive, because dropping the `fixed.Shape` leaves only the raw integer:

```
>>> Const.cast(fixed.Const(0.5, shape=fixed.SQ(2, 2))).value == 2
True
```

I think this kind of use case is better covered by `as_float()`:

```python
fixed.Const(0.5, shape=fixed.SQ(2, 2)).as_float() == 0.5
```

In the example from @goekce, if we really want to do the comparison outside the simulation context for some reason:

```python
print(ctx.get(dut.f).as_float() == 0.0)  # True
```
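The raw comparison is unintuitive because `Const.cast(...).value` returns the underlying integer, i.e. the real value scaled by `2**f_bits`. A minimal pure-Python sketch of that mapping, where the helper names `to_raw` and `from_raw` are hypothetical and round-to-nearest is an assumption:

```python
def to_raw(value: float, f_bits: int) -> int:
    # Scale by 2**f_bits and round to the nearest integer (assumed rounding)
    return round(value * 2 ** f_bits)


def from_raw(raw: int, f_bits: int) -> float:
    # Inverse mapping, analogous to as_float()
    return raw / 2 ** f_bits


# SQ(2, 2) has 2 fractional bits, so 0.5 is stored as 0.5 * 4 = 2
print(to_raw(0.5, 2))  # 2
print(from_raw(2, 2))  # 0.5
```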
What I mean is that `Const.cast()` is the Amaranth entry point for, e.g., resolving concatenations or indexing into an Amaranth constant value. It is up to the end user or library how this API is used; I only wanted to mention that it is available.
I am confused about how `Const.cast()` would be valuable here.
I thought `Const.cast()` was used in the following expression, so that the signal directly provides a value that can be used at elaboration time:

```
(Pdb) ctx.get(dut.s)
0
```

But it is probably not used:

```
(Pdb) Const.cast(dut.s)
*** TypeError: Value (sig s) cannot be converted to an Amaranth constant
```

Or do you just mean that `Const.cast()` could be used for fixed values wherever a constant is needed, such as in `with m.Case(...):`, as documented in the RFC about `Const.cast()`?
If a `Mux` is an operand, the arithmetic result is wrong. I did not have time to investigate further, but I have the feeling that `Mux` cannot convey that its result should be interpreted as a fixed-point shape.
`self.a` should output 0.25, but it does not:
```python
from amaranth import *
from amaranth.lib import fixed
from amaranth.lib.wiring import Component, In, Out

class C(Component):
    y: Out(fixed.SQ(8, 4))

    def elaborate(self, platform):
        m = Module()
        self.a = Signal(self.y.shape())
        self.b = Signal(self.y.shape())
        self.c = Signal(self.y.shape())
        m.d.comb += [
            self.a.eq(fixed.Const(-1.125) + Mux(1, fixed.Const(1.375), 0)),
            self.b.eq(fixed.Const(-1.125) + fixed.Const(1.375)),
            self.c.eq(Mux(1, fixed.Const(1.375), 0)),
        ]
        m.d.sync += self.y.eq(self.a)
        return m

from amaranth.sim import Period, Simulator

dut = C()

async def bench(ctx):
    await ctx.tick()
    print(ctx.get(dut.a).as_float())
    print(ctx.get(dut.b).as_float())
    print(ctx.get(dut.c).as_float())
    print(Mux(1, fixed.Const(1.375), 0).shape())

sim = Simulator(dut)
sim.add_clock(Period())
sim.add_testbench(bench)
sim.run()
```
Result:

```
9.875
0.25
0.6875
unsigned(4)
```
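These numbers can be reproduced by hand. The pure-Python model below rests on two assumptions that are consistent with the printed output but not confirmed from the implementation: `fixed.Const(1.375)` infers a 4-bit shape with 3 fractional bits (raw value 11, matching the `unsigned(4)` above), and `fixed.Const(-1.125)` has 3 fractional bits (raw value -9). Once `Mux` strips the fixed-point shape, its output behaves like a plain integer: as an operand of `+` it appears to be treated as a whole number and scaled up by `2**f_bits`, while on `.eq()` its bits appear to be copied verbatim into `SQ(8, 4)`. Both helper functions are hypothetical.

```python
def add_fixed_and_plain(fixed_raw: int, f_bits: int, plain: int) -> float:
    # A shapeless Value is treated as an integer, so it is shifted up
    # to the fixed-point operand's f_bits before the addition
    return (fixed_raw + (plain << f_bits)) / 2 ** f_bits


def assign_plain(plain: int, target_f_bits: int) -> float:
    # .eq() from a shapeless Value copies the raw bits unchanged
    return plain / 2 ** target_f_bits


MUX_OUT = 11  # raw bits of fixed.Const(1.375) (1.375 * 2**3) after the Mux
print(add_fixed_and_plain(-9, 3, MUX_OUT))  # 9.875, observed for self.a
print(assign_plain(MUX_OUT, 4))             # 0.6875, observed for self.c
print((-9 + MUX_OUT) / 2 ** 3)              # 0.25, intended (self.b)
```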
Thanks, I haven't played with `Mux` much in combination with `fixed.Value`; this makes a good candidate for more test cases. There are some tests on the `fixed.Value(shape, Mux(a, b, c))` statement inside `fixed.Value.clamp(...)`, but there the two shapes are guaranteed to be the same, so I guess `Mux` losing the underlying type had no adverse effect and I didn't catch it. In your example, forcing the incoming type to match also 'fixes' it: `self.c.eq(Mux(1, fixed.Const(1.375, self.y.shape()), 0))`
Ideally we want to attack this without touching anything outside `lib.fixed`. On a quick skim, I think the information is lost on the cast to `Value` here. This also looks related to `Shape._unify`: we need a similar `_unify` operation that automatically performs a `.reshape()` up to `max(f_bits)`, in the same fashion as `fixed.Value._binary_op` already does.
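A `_unify`-style step for fixed-point operands would align both sides to the larger number of fractional bits before selecting or combining them. A minimal pure-Python sketch of that alignment on raw integers, followed by a shape-preserving select built on it; `unify_fixed`, `fixed_mux`, and the `(raw, i_bits, f_bits)` tuple representation are hypothetical, and signedness is ignored:

```python
def unify_fixed(a, b):
    """Align two (raw, i_bits, f_bits) operands to a common format.

    Returns both aligned raw values and the common (i_bits, f_bits).
    """
    a_raw, a_i, a_f = a
    b_raw, b_i, b_f = b
    f_bits = max(a_f, b_f)
    i_bits = max(a_i, b_i)
    # Reshape up to f_bits: a left shift adds fractional bits exactly
    a_raw <<= f_bits - a_f
    b_raw <<= f_bits - b_f
    return a_raw, b_raw, (i_bits, f_bits)


def fixed_mux(test, a, b):
    # Select between aligned operands; the result keeps a fixed-point shape
    a_raw, b_raw, shape = unify_fixed(a, b)
    return (a_raw if test else b_raw), shape


# 1.375 with 3 fractional bits vs. 0.75 in SQ(4, 2)
print(unify_fixed((11, 1, 3), (3, 4, 2)))       # (11, 6, (4, 3))
print(fixed_mux(False, (11, 1, 3), (3, 4, 2)))  # (6, (4, 3)), i.e. 0.75
```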
Other types seem to lose their type information through `Mux` in the same way:
```python
from amaranth import *
from amaranth.lib import enum, data

class Funct(enum.Enum, shape=4):
    ADD = 0
    SUB = 1
    MUL = 2

rgb565_layout = data.StructLayout({
    "red": 5,
    "green": 6,
    "blue": 5,
})

print(Mux(1, Funct.ADD, Funct.MUL).shape())  # unsigned(4)
print(Mux(1, Signal(rgb565_layout), Signal(rgb565_layout)).shape())  # unsigned(16)
```
Given this, making `Mux` aware of `fixed.Value` would make it behave differently from how existing types behave. What do you think, @whitequark? I would like to avoid touching the infrastructure underneath `Mux` here if we can; we could then address preserving shapes through `Mux` in a future RFC.
Yeah, that sounds reasonable. We actually discussed the option of preserving shapes through Mux and SwitchValue; @wanda-phi, do you recall what came out of those discussions?
Forcing the incoming type to match, as in `self.c.eq(Mux(1, fixed.Const(1.375, self.y.shape()), 0))`, does not fix the result if the `Mux` is itself an operand 😕
```python
...
self.c.eq(Mux(1, fixed.Const(1.375), 0)),
self.d.eq(Mux(1, fixed.Const(1.375, self.y.shape()), 0)),
self.e.eq(fixed.Const(-1.125) + Mux(1, fixed.Const(1.375, self.y.shape()), 0)),
self.f.eq(fixed.Const(-1.125, self.y.shape()) + Mux(1, fixed.Const(1.375, self.y.shape()), 0)),
self.g.eq(fixed.Const(-1.125, self.y.shape()) + Mux(1, fixed.Const(1.375, self.y.shape()), fixed.Const(0, self.y.shape()))),
...
print(ctx.get(dut.d).as_float())
...
```
Outputs:

```
0.6875
1.375
20.875
20.875
20.875
```
The workaround I found was to use an m.If and create two different arithmetic assignments.
If one really wants a `lib.fixed`-aware `Mux`, an easier workaround might be to build it on top of the primitives already supplied by `lib.fixed`, for example:
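```python
from amaranth import *
from amaranth.lib import fixed

def FMux(test, a, b):
    if isinstance(a, fixed.Value) and isinstance(b, fixed.Value):
        # Align both operands to the larger f_bits first, so that the shape
        # chosen below matches the fractional bits that leave the Mux.
        f_bits = max(a.f_bits, b.f_bits)
        a, b = a.reshape(f_bits), b.reshape(f_bits)
        return fixed.Value(a.shape() if a.i_bits >= b.i_bits else b.shape(),
                           Mux(test, a, b))
    elif isinstance(a, fixed.Value):
        # b is a plain value: its raw bits are interpreted in a's format
        return fixed.Value(a.shape(), Mux(test, a, b))
    elif isinstance(b, fixed.Value):
        return fixed.Value(b.shape(), Mux(test, a, b))
    else:
        raise TypeError("FMux should only be used on fixed.Value")
```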
Applied to your example:

```python
self.a.eq(fixed.Const(-1.125) + FMux(1, fixed.Const(1.375), 0)),
self.b.eq(fixed.Const(-1.125) + fixed.Const(1.375)),
self.c.eq(FMux(1, fixed.Const(1.375), 0)),
self.d.eq(fixed.Const(-1.125) + FMux(1, fixed.Const(1.375), 0)),
self.e.eq(fixed.Const(-1.125) + FMux(1, fixed.Const(1.375), 0)),
self.f.eq(fixed.Const(-1.125) + FMux(1, fixed.Const(1.375), fixed.Const(0))),
```

Prints:

```
fixed.Const(0.25, SQ(8, 4))
fixed.Const(0.25, SQ(8, 4))
fixed.Const(1.375, SQ(8, 4))
fixed.Const(0.25, SQ(8, 4))
fixed.Const(0.25, SQ(8, 4))
fixed.Const(0.25, SQ(8, 4))
```
I would not, however, include such a workaround in this RFC. I think attacking this properly would imply modifying `Mux` to preserve shapes, which should be a separate RFC, as it has wider consequences than just adding a module to the standard library, which is what `lib.fixed` aims to do.
Whether `lib.fixed` should be blocked by a shape-preserving `Mux` implementation is not something I have a strong opinion on. At least in my DSP codebases, I haven't found much use for `Mux` outside of `.clamp()`, which is provided here and is already shape-preserving.
Thanks for the FMux. I will try it.
IMHO `lib.fixed` could attract other users, so I would not block it if a shape-preserving `Mux` is not a high priority right now. The `Mux` issue should be noted, though.
- I tried `FMux` and it works 🙂. I find it a concise way of expressing statements like `z = a + (b if c else d) + e` instead of writing `if ... then z = a + b + e else z = a + d + e`.
- I noticed:

  ```
  In [30]: fixed.Const(-2)
  Out[30]: fixed.Const(-2.0, SQ(3, 0))
  ```

  -2 requires only two bits in two's complement. Is it intended that it generates 3 bits?
Agreed, this is an edge case in const size inference. Let me add a test case for that and fix it. Thanks! :+1:
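For reference, the minimum two's-complement width of an integer can be computed from Python's `int.bit_length()`; the helper name `min_signed_bits` is hypothetical. It confirms that -2 fits in 2 bits (range -2 to 1), so, assuming the integer-bit count of `SQ(i, f)` includes the sign bit (consistent with `SQ(2, 2)` being a 4-bit signal above), `fixed.Const(-2)` would fit in `SQ(2, 0)`:

```python
def min_signed_bits(value: int) -> int:
    """Smallest two's-complement width that can represent `value`."""
    # For negative numbers, ~value == -value - 1 is non-negative and needs
    # the same number of magnitude bits; one extra bit holds the sign.
    magnitude = ~value if value < 0 else value
    return magnitude.bit_length() + 1


for v in (-1, -2, -3, 1, 2):
    print(v, min_signed_bits(v))
# -1 1
# -2 2
# -3 3
# 1 2
# 2 3
```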
I additionally noticed that reshape(shape) from the RFC is not implemented. Is this intended?
Yes, this is intended. My plan was to change the RFC text to reflect this, as I couldn't find a good use case for the second form of `reshape()`. I touched on it in the RFC PR here: https://github.com/amaranth-lang/rfcs/pull/41#discussion_r1855276571