[FIRRTL] Subword assignment support via rewriting to read-modify-write

Open zyedidia opened this issue 1 year ago • 21 comments

This implements subword assignment by allowing the BitsPrimOp as a sink and then rewriting connections to it during the expand-whens pass.
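
To make the rewrite concrete, here is a minimal Python sketch of the read-modify-write idea on plain integers. It is not the code from the pass: `bits` and `subword_assign` are hypothetical helper names chosen for illustration, and `prev` stands for whatever value the sink held before the subword connect.

```python
def bits(value: int, hi: int, lo: int) -> int:
    """Extract bits hi..lo (inclusive) of value, like FIRRTL bits(value, hi, lo)."""
    return (value >> lo) & ((1 << (hi - lo + 1)) - 1)


def subword_assign(prev: int, width: int, hi: int, lo: int, rhs: int) -> int:
    """Model `x[hi:lo] <= rhs` as a full-width connect built from the previous value.

    Conceptually: x <= cat(bits(prev, width-1, hi+1), rhs, bits(prev, lo-1, 0)),
    where an empty upper or lower slice is simply left out.
    """
    assert 0 <= lo <= hi < width
    field_width = hi - lo + 1
    # New bits, truncated to the width of the assigned slice.
    result = bits(rhs, field_width - 1, 0) << lo
    if hi + 1 < width:
        # Bits above the slice are carried over from the previous value.
        result |= bits(prev, width - 1, hi + 1) << (hi + 1)
    if lo > 0:
        # Bits below the slice are carried over as well.
        result |= bits(prev, lo - 1, 0)
    return result
```

For example, `subword_assign(0b10101010, 8, 3, 2, 0b01)` replaces bits 3..2 and returns `0b10100110`; every other bit is read back from `prev`.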

When a value is subword-assigned without having a previous connection, an "uninitialized" wire is created to represent the previous value. During initialization checking, the checker makes sure that no uninitialized wires are reachable from any connection in the module. Some additional canonicalizations of the bits primop are needed to make sure that uses of uninitialized wires are eliminated if all bits are assigned. The canonicalizations are implemented in ExpandWhens because they are necessary for the pass. I also added them to FIRRTLFolds.cpp because they might be generally useful, but I'm not sure how to use the ones from FIRRTLFolds during expand-whens because I couldn't figure out how to run canonicalizations during a pass rather than as a separate pass. There's probably a better solution here, so let me know if there's something I should change.
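
The initialization check described above amounts to a reachability query. The following is a rough Python sketch under simplifying assumptions, not the checker from the PR: values are modeled as string names, `drivers` maps each name to the names it is computed from, and `uninit` is the set of placeholder wires. All of these names are hypothetical.

```python
from typing import Dict, List, Set


def sinks_reaching_uninitialized(
    drivers: Dict[str, List[str]], uninit: Set[str], sinks: List[str]
) -> List[str]:
    """Return the sinks whose driver cone contains an uninitialized wire."""

    def reachable(start: str) -> Set[str]:
        # Iterative depth-first walk over everything `start` depends on.
        seen: Set[str] = set()
        stack = [start]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            stack.extend(drivers.get(node, []))
        return seen

    return [s for s in sinks if reachable(s) & uninit]


# Only the low bit of x was assigned: the placeholder is still reachable.
partial = {"x": ["cat0"], "cat0": ["uninit_x", "v"]}
# All bits assigned and canonicalized: the placeholder has dropped out.
full = {"x": ["cat1"], "cat1": ["a", "b"]}

print(sinks_reaching_uninitialized(partial, {"uninit_x"}, ["x"]))
print(sinks_reaching_uninitialized(full, {"uninit_x"}, ["x"]))
```

This also shows why the bits canonicalizations matter for the pass: without them, a fully assigned value could still mention the placeholder wire syntactically and would be flagged.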

The parser has also been modified to support new syntax for BitsPrimOp: x[hi:lo] is equivalent to bits(x, hi, lo), and x[i] is equivalent to bits(x, i, i). I'm not sure if I made this modification in the best way, or whether we should use this new syntax at all rather than just supporting bits(x, hi, lo) <= ... as a sink (that form is currently supported as well). The parser uses the subaccess cache for the single-bit index, but not for the range index.
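
As a rough illustration of the equivalence, here is a small Python sketch that desugars the bracket syntax textually. It is not how the CIRCT parser works: `desugar` is a hypothetical helper, it only handles literal indices, and it ignores the fact that x[i] is also vector indexing syntax in FIRRTL, which a real parser presumably has to disambiguate by type.

```python
import re

# name[hi:lo] or name[i], where name may contain dots (e.g. io.wdata).
_INDEX = re.compile(
    r"(?P<name>[A-Za-z_][A-Za-z0-9_.]*)\[(?P<hi>\d+)(?::(?P<lo>\d+))?\]"
)


def desugar(expr: str) -> str:
    """Rewrite x[hi:lo] to bits(x, hi, lo) and x[i] to bits(x, i, i)."""

    def repl(match: "re.Match[str]") -> str:
        hi = match.group("hi")
        # A single index selects one bit, so lo defaults to hi.
        lo = match.group("lo") if match.group("lo") is not None else hi
        return f"bits({match.group('name')}, {hi}, {lo})"

    return _INDEX.sub(repl, expr)


print(desugar("tmp[3:1] <= not(tmp[2:0])"))
print(desugar("x[0] <= UInt(1)"))
```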

This implementation does not accept any combinational loops.

Let me know what you think, thanks!

zyedidia avatar Aug 03 '22 20:08 zyedidia

It would be good to have some parser tests (parse-basic.fir) and flow checking tests (connect-errors.mlir).

youngar avatar Aug 03 '22 23:08 youngar

I removed the recursion from canonicalizeBits, so I think all the comments have been addressed, except the one regarding code duplication with FIRRTLFolds. I propose to just remove the new code from FIRRTLFolds and leave that file unchanged. Let me know if you'd like more changes on the code that has been changed since the earlier comments. Thanks!

zyedidia avatar Aug 15 '22 21:08 zyedidia

I am having a hard time parsing the MLIR-internal firrtl syntax, but as far as I can tell the following kinds of tests might be missing:

  • subword assignment to a field in a Bundle
  • subword assignment to a field in a Vec
  • subword assignment to a register
  • subword assignment to a memory data port

Ideally you would add the product of {Bundle field, Vec field, UInt, SInt} x {Port, Wire, Reg, Memory Port} as tests. Maybe there are some combinations that would make sense to skip.

ekiwi avatar Aug 22 '22 20:08 ekiwi

There is a test for subword assignment to a register. Bundles and Vecs have been lowered away by the time this pass runs, so I'm not sure if we need tests for subword assignments to those, but if so they would have to go in the end-to-end firtool test. I have a memory port test using FIRRTL memories that I could add (it would be a little large).

In investigating the memory ports though I realized that there is an issue with subword assignment to CHIRRTL memory ports. The LowerCHIRRTL pass runs before this pass and is not able to handle a bits op on the left-hand-side. I'm not sure how to address this, since I am not familiar with CHIRRTL memories (is there documentation on how they work anywhere?). Currently the LowerCHIRRTL pass will give a flow error if a bits op is used for subword assignment.

zyedidia avatar Aug 22 '22 22:08 zyedidia

I am curious, with this PR, are you able to simplify:

circuit subword_2:
  module subword_2:
    input c : UInt<1>
    output x : UInt<2>
    when c:
      x[0] is invalid
      x[1] <= UInt(1)
    else:
      x[0] <= UInt(1)
      x[1] is invalid

to

circuit subword_2:
  module subword_2:
    input c : UInt<1>
    output x : UInt<2>
    x <= UInt<2>(3)

??

ekiwi avatar Aug 24 '22 19:08 ekiwi

CIRCT just treats invalid values as 0, so this gets generated for your example:

module subword_2(
  input        c,
  output [1:0] x);

  assign x = c ? 2'h2 : 2'h1;
endmodule

It would be cool if there were smarter invalid value optimization that could simplify this though!

zyedidia avatar Aug 24 '22 20:08 zyedidia

> It would be cool if there were smarter invalid value optimization that could simplify this though!

In general, firrtl does some of these optimizations (even though they can be a little problematic, since they can break local invariants in the presence of invalid).

So for example:

circuit invalid_opt:
  module invalid_opt:
    input c : UInt<1>
    output x : UInt<2>

    x is invalid
    when c:
      x <= UInt(3)

is compiled to

circuit invalid_opt :
  module invalid_opt :
    input c : UInt<1>
    output x : UInt<2>

    x <= UInt<2>("h3")

ekiwi avatar Aug 24 '22 20:08 ekiwi

I decided to follow your lead and just plug in zeros for now.

ekiwi avatar Aug 24 '22 20:08 ekiwi

@zyedidia: What does your current implementation do with this circuit?

Did you manage to make the combinational loop checker accept this?

circuit m:
  module m:
    input y : UInt<1>
    output x : UInt<1>
    wire tmp : UInt<2>
    x <= not(tmp[1])
    tmp[1] <= not(tmp[0])
    tmp[0] <= y

Or this circuit which might be closer to something users might actually want to write:

circuit m:
  module m:
    input y : UInt<1>
    output x : UInt<1>
    wire tmp : UInt<4>
    x <= not(tmp[3])
    tmp[3:1] <= not(tmp[2:0])
    tmp[0] <= y

ekiwi avatar Aug 25 '22 18:08 ekiwi

This PR doesn't modify the combinational cycle checker, so those examples cause errors. If you disable the check with --firrtl-check-comb-cycles=false then it generates this for the first circuit:

module m(
  input  y,
  output x);

  assign x = y;
endmodule

and this for the second circuit:

module m(
  input  y,
  output x);

  wire [2:0] _GEN;
  assign _GEN = ~{_GEN[1:0], y};
  assign x = ~(_GEN[2]);
endmodule

I think the second one has a true combinational cycle (on tmp[2:1]). Perhaps you meant to have x <= not(tmp[1]) instead? In that case the two examples compile to the same thing.

zyedidia avatar Aug 25 '22 18:08 zyedidia

> CIRCT just treats invalid values as 0, so this gets generated for your example:

Note: this isn't correct. CIRCT implements the same context-sensitive interpretation of invalid that the SFC uses. I had to do some excavation to discern what this is. It's documented here: https://circt.llvm.org/docs/Dialects/FIRRTL/RationaleFIRRTL/#interpretation-of-undefined-behavior. In test form: https://github.com/llvm/circt/blob/main/test/Dialect/FIRRTL/SFCTests/invalid-interpretations.fir

The only known divergence from the SFC is around SFC optimizations that explicitly check for self-connects to registers.

E.g., your example works as expected:

# cat Foo.fir && firtool Foo.fir -ir-fir | circt-translate -export-firrtl         
circuit invalid_opt:
  module invalid_opt:
    input c : UInt<1>
    output x : UInt<2>

    x is invalid
    when c:
      x <= UInt(3)
circuit invalid_opt :
  module invalid_opt :
    input c : UInt<1>
    output x : UInt<2>

    x <= UInt<2>(3) @[<stdin> 5:7]

seldridge avatar Aug 25 '22 19:08 seldridge

> I think the second one has a true combinational cycle (on tmp[2:1]). Perhaps you meant to have x <= not(tmp[1]) instead? In that case the two examples compile to the same thing.

It shouldn't on the bit-level since tmp[3:1] <= not(tmp[2:0]) expands to:

tmp[3] <= not(tmp[2])
tmp[2] <= not(tmp[1])
tmp[1] <= not(tmp[0])
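
The difference between the word-level and bit-level view can be sketched in a few lines of Python. This is only an illustration with hypothetical names (`has_cycle`, `word_level`, `bit_level`), not the CIRCT combinational cycle checker: each graph maps a signal to the signals it depends on.

```python
def has_cycle(deps):
    """Depth-first search for a cycle in a dependency graph."""
    state = {}

    def visit(node):
        if state.get(node) == "active":
            return True  # reached a node that is still on the DFS stack
        if state.get(node) == "done":
            return False
        state[node] = "active"
        found = any(visit(d) for d in deps.get(node, []))
        state[node] = "done"
        return found

    return any(visit(n) for n in list(deps))


# Whole-wire view: tmp depends on itself through tmp[3:1] <= not(tmp[2:0]).
word_level = {"x": ["tmp"], "tmp": ["tmp", "y"]}

# Per-bit view of the same circuit: a simple chain from y to x.
bit_level = {
    "x": ["tmp[3]"],
    "tmp[3]": ["tmp[2]"],
    "tmp[2]": ["tmp[1]"],
    "tmp[1]": ["tmp[0]"],
    "tmp[0]": ["y"],
}

print(has_cycle(word_level))
print(has_cycle(bit_level))
```

A checker that tracks whole wires sees the first graph and reports a loop, while a bit-granular analysis sees the second one, which is acyclic.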

ekiwi avatar Aug 25 '22 19:08 ekiwi

Ah yes that's right.

zyedidia avatar Aug 25 '22 19:08 zyedidia

> Ah yes that's right.

For now I think it is fine that this is rejected by the comb loop checker. However, I have a feeling that once people start using sub-word assignments, they will start complaining about us rejecting this kind of code.

ekiwi avatar Aug 25 '22 19:08 ekiwi

Here is a simple example with a CHIRRTL memory:

circuit Ram :
  module Ram :
    input clock : Clock
    input reset : UInt<1>
    input io : { addr : UInt<32>, wdata : UInt<8>}

    smem mem : UInt<8> [1024] @[Ram.scala 14:24]
    node _T = bits(io.addr, 9, 0)
    write mport MPORT = mem[_T], clock
    MPORT[1:0] <= io.wdata

Currently MFC throws a flow error on this: error: connect has invalid flow: the destination expression has source flow, expected sink or duplex flow. I'm a bit confused about this, because the bits should retain the flow of the underlying expression (duplex in this case). If I change the bits flow to duplex (just to see what happens), in this case it produces an initialization checking error during LowerFIRRTLTypes. This seems like a reasonable behavior, or should it do something different? In a more complex case with some masks, it produces an incorrect result (if I ignore the flow error). I can post that example too, but would like to understand what's going on here with the flow first.

zyedidia avatar Aug 30 '22 19:08 zyedidia

@zyedidia The flow checking is implemented as a verifier that is run after each pass. In this case, it is failing after the LowerCHIRRTL pass. It attaches the bits operation to the rdata port of the lowered memory (trimmed down):

%8 = "firrtl.mem"() {annotations = [], depth = 1024 : i64, name = "mem", nameKind = #firrtl<name_kind interesting_name>, portAnnotations = [[]], portNames = ["MPORT"], readLatency = 1 : i32, ruw = 0 : i32, writeLatency = 1 : i32} : () -> !firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>
%12 = "firrtl.subfield"(%8) {fieldIndex = 3 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.uint<8>
%16 = "firrtl.bits"(%12) {hi = 1 : i32, lo = 0 : i32} : (!firrtl.uint<8>) -> !firrtl.uint<2>
"firrtl.strictconnect"(%16, %19) : (!firrtl.uint<2>, !firrtl.uint<2>) -> ()

Here is a more verbose output:

firrtl.circuit "Ram"  {
  firrtl.module @Ram(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %io: !firrtl.bundle<addr: uint<32>, wdata: uint<8>>) {
    %0 = firrtl.subfield %io(1) : (!firrtl.bundle<addr: uint<32>, wdata: uint<8>>) -> !firrtl.uint<8>
    %1 = firrtl.subfield %io(0) : (!firrtl.bundle<addr: uint<32>, wdata: uint<8>>) -> !firrtl.uint<32>
    %mem = chirrtl.seqmem  Undefined  : !chirrtl.cmemory<uint<8>, 1024>
    %MPORT_data, %MPORT_port = chirrtl.memoryport Write %mem  {name = "MPORT"} : (!chirrtl.cmemory<uint<8>, 1024>) -> (!firrtl.uint<8>, !chirrtl.cmemoryport)
    %2 = firrtl.bits %MPORT_data 1 to 0 : (!firrtl.uint<8>) -> !firrtl.uint<2>
    %3 = firrtl.bits %1 9 to 0 : (!firrtl.uint<32>) -> !firrtl.uint<10>
    %_T = firrtl.node  %3  : !firrtl.uint<10>
    chirrtl.memoryport.access %MPORT_port[%_T], %clock : !chirrtl.cmemoryport, !firrtl.uint<10>, !firrtl.clock
    %4 = firrtl.tail %0, 6 : (!firrtl.uint<8>) -> !firrtl.uint<2>
    firrtl.strictconnect %2, %4 : !firrtl.uint<2>
  }
}

<stdin>:13:7: error: 'firrtl.strictconnect' op connect has invalid flow: the destination expression has source flow, expected sink or duplex flow
      firrtl.strictconnect %2, %4 : !firrtl.uint<2>
      ^
<stdin>:13:7: note: see current operation: "firrtl.strictconnect"(%16, %19) : (!firrtl.uint<2>, !firrtl.uint<2>) -> ()
<stdin>:8:12: note: the destination was defined here
      %2 = firrtl.bits %MPORT_data 1 to 0 : (!firrtl.uint<8>) -> !firrtl.uint<2>
           ^
// -----// IR Dump After LowerCHIRRTLPass Failed (firrtl-lower-chirrtl) //----- //
"firrtl.module"() ({
^bb0(%arg0: !firrtl.clock, %arg1: !firrtl.uint<1>, %arg2: !firrtl.bundle<addr: uint<32>, wdata: uint<8>>):
  %0 = "firrtl.invalidvalue"() : () -> !firrtl.uint<1>
  %1 = "firrtl.invalidvalue"() : () -> !firrtl.uint<8>
  %2 = "firrtl.constant"() {value = 1 : ui1} : () -> !firrtl.uint<1>
  %3 = "firrtl.invalidvalue"() : () -> !firrtl.clock
  %4 = "firrtl.constant"() {value = 0 : ui1} : () -> !firrtl.uint<1>
  %5 = "firrtl.invalidvalue"() : () -> !firrtl.uint<10>
  %6 = "firrtl.subfield"(%arg2) {fieldIndex = 1 : i32} : (!firrtl.bundle<addr: uint<32>, wdata: uint<8>>) -> !firrtl.uint<8>
  %7 = "firrtl.subfield"(%arg2) {fieldIndex = 0 : i32} : (!firrtl.bundle<addr: uint<32>, wdata: uint<8>>) -> !firrtl.uint<32>
  %8 = "firrtl.mem"() {annotations = [], depth = 1024 : i64, name = "mem", nameKind = #firrtl<name_kind interesting_name>, portAnnotations = [[]], portNames = ["MPORT"], readLatency = 1 : i32, ruw = 0 : i32, writeLatency = 1 : i32} : () -> !firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>
  %9 = "firrtl.subfield"(%8) {fieldIndex = 0 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.uint<10>
  "firrtl.strictconnect"(%9, %5) : (!firrtl.uint<10>, !firrtl.uint<10>) -> ()
  %10 = "firrtl.subfield"(%8) {fieldIndex = 1 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.uint<1>
  "firrtl.strictconnect"(%10, %4) : (!firrtl.uint<1>, !firrtl.uint<1>) -> ()
  %11 = "firrtl.subfield"(%8) {fieldIndex = 2 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.clock
  "firrtl.strictconnect"(%11, %3) : (!firrtl.clock, !firrtl.clock) -> ()
  %12 = "firrtl.subfield"(%8) {fieldIndex = 3 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.uint<8>
  %13 = "firrtl.subfield"(%8) {fieldIndex = 4 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.uint<1>
  "firrtl.strictconnect"(%13, %4) : (!firrtl.uint<1>, !firrtl.uint<1>) -> ()
  %14 = "firrtl.subfield"(%8) {fieldIndex = 5 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.uint<8>
  "firrtl.strictconnect"(%14, %1) : (!firrtl.uint<8>, !firrtl.uint<8>) -> ()
  %15 = "firrtl.subfield"(%8) {fieldIndex = 6 : i32} : (!firrtl.bundle<addr: uint<10>, en: uint<1>, clk: clock, rdata flip: uint<8>, wmode: uint<1>, wdata: uint<8>, wmask: uint<1>>) -> !firrtl.uint<1>
  "firrtl.strictconnect"(%15, %0) : (!firrtl.uint<1>, !firrtl.uint<1>) -> ()
  %16 = "firrtl.bits"(%12) {hi = 1 : i32, lo = 0 : i32} : (!firrtl.uint<8>) -> !firrtl.uint<2>
  %17 = "firrtl.bits"(%7) {hi = 9 : i32, lo = 0 : i32} : (!firrtl.uint<32>) -> !firrtl.uint<10>
  %18 = "firrtl.node"(%17) {annotations = [], name = "_T", nameKind = #firrtl<name_kind interesting_name>} : (!firrtl.uint<10>) -> !firrtl.uint<10>
  "firrtl.strictconnect"(%9, %18) : (!firrtl.uint<10>, !firrtl.uint<10>) -> ()
  "firrtl.strictconnect"(%10, %2) : (!firrtl.uint<1>, !firrtl.uint<1>) -> ()
  "firrtl.strictconnect"(%11, %arg0) : (!firrtl.clock, !firrtl.clock) -> ()
  "firrtl.strictconnect"(%15, %4) : (!firrtl.uint<1>, !firrtl.uint<1>) -> ()
  %19 = "firrtl.tail"(%6) {amount = 6 : i32} : (!firrtl.uint<8>) -> !firrtl.uint<2>
  "firrtl.strictconnect"(%16, %19) : (!firrtl.uint<2>, !firrtl.uint<2>) -> ()
}) {annotations = [], parameters = [], portAnnotations = [], portDirections = 0 : i3, portNames = ["clock", "reset", "io"], portSyms = [], portTypes = [!firrtl.clock, !firrtl.uint<1>, !firrtl.bundle<addr: uint<32>, wdata: uint<8>>], sym_name = "Ram"} : () -> ()

This might help explain what should be happening: https://github.com/llvm/circt/blob/daddbc63b7e60156d048e361a41855bed629cc77/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp#L541-L547

To get this working, I think that you need to do two things:

  1. Update inferMemoryPortKind to "look through" bits operations.
  2. Add a visitExpr(BitsOp bits) that calls cloneSubindexOpForMemory

youngar avatar Aug 31 '22 05:08 youngar

During expand whens, if we form a mux with invalid and another value, we choose the other value. E.g. mux(p, v, invalid) => v. We found that if we apply this optimization in general, it started breaking cores, so we only apply it during expand whens as SFC does. I was trying to run a quick test to see if this was working as I might expect and found two test cases I think should produce the same result but do not.

Edit: There is a bunch of discussion about this above that I missed while I was on vacation. If we're just plugging in 0 for invalid, then maybe the first example is incorrect?

First example seems fine:

circuit Bits :
  module Bits :
    input p : UInt<1>
    output o : UInt<8>

    wire w : UInt<8>
    w is invalid
    when p:
      w[0] <= UInt<1>(1)
    o <= w
module Bits(
  input        p,
  output [7:0] o);

  assign o = 8'h1;
endmodule

Second one looks like the invalid is lowered to 0:

circuit Bits :
  module Bits :
    input p : UInt<1>
    output o : UInt<8>

    wire w : UInt<8>
    w is invalid
    when p:
      w[0] <= UInt<1>(1)
    else:
      w[0] is invalid
    o <= w
module Bits(
  input        p,
  output [7:0] o);

  assign o = {7'h0, p};
endmodule

youngar avatar Sep 02 '22 06:09 youngar

I think this behavior is caused by the invalid value optimizer seeing a bits of an invalid value rather than the invalid value directly, so it doesn't apply the optimization. We could either add a canonicalization within the pass that transforms bits(invalid, hi, lo) -> invalid<hi-lo+1>, or change the mux optimization in flattenConditionalConnections to look through bits operations. I'm not sure which is the better approach, but I'm leaning towards the canonicalization. What do you think?

zyedidia avatar Sep 02 '22 07:09 zyedidia

I'm not sure if this behavior is desirable, but I have made an implementation using the canonicalization approach here: https://github.com/zyedidia/circt/commit/0cd42feeebceb06e12a2b36fd96ef50dcc3bb76e. It also required implementing a special version of createOrFold<CatPrimOp> that specifically folds invalid values into a single invalid value, rather than a 0 constant.
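
For readers who don't want to open the commit, the two folds can be sketched on a toy expression tree in Python. This is an illustrative model only, not the code in the linked commit: expressions are tuples such as ("invalid", width), ("bits", operand, hi, lo) and ("cat", lhs, rhs), and `canonicalize` is a hypothetical name.

```python
def canonicalize(expr):
    """Fold bits-of-invalid and cat-of-invalids into a single invalid value."""
    kind = expr[0]
    if kind == "bits":
        _, operand, hi, lo = expr
        operand = canonicalize(operand)
        if operand[0] == "invalid":
            # bits(invalid, hi, lo) -> invalid<hi-lo+1>
            return ("invalid", hi - lo + 1)
        return ("bits", operand, hi, lo)
    if kind == "cat":
        _, lhs, rhs = expr
        lhs, rhs = canonicalize(lhs), canonicalize(rhs)
        if lhs[0] == "invalid" and rhs[0] == "invalid":
            # Keep the result invalid instead of folding it to a 0 constant.
            return ("invalid", lhs[1] + rhs[1])
        return ("cat", lhs, rhs)
    return expr


w = ("invalid", 8)
rebuilt = ("cat", ("bits", w, 7, 1), ("bits", w, 0, 0))
print(canonicalize(rebuilt))
```

With both folds in place, a value rebuilt from slices of an invalid wire is again a plain invalid value, so a when-merging optimization that only recognizes invalid directly can still fire.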

zyedidia avatar Sep 02 '22 07:09 zyedidia

The lack of when optimization here is probably safe given that this is totally new code. However, it would be ideal, as @youngar states, to be consistent with the existing invalid interpretations. Namely, an invalid in a when path means "choose the other path unconditionally" while it means zero in most other contexts.
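
The difference between the two interpretations can be checked on the subword_2 example from earlier in this thread. The sketch below is a hypothetical per-bit model, not CIRCT code: `merge_bit` and `subword_2` are illustrative names, and an invalid bit is modeled as `None`.

```python
INVALID = None


def merge_bit(c, then_bit, else_bit, when_aware):
    """Resolve one bit that is driven differently in the two branches of a when."""
    if when_aware:
        # An invalid on one path means: take the other path unconditionally.
        if then_bit is INVALID and else_bit is not INVALID:
            return else_bit
        if else_bit is INVALID and then_bit is not INVALID:
            return then_bit
    # Everywhere else (or if both paths are invalid), invalid becomes 0.
    then_bit = 0 if then_bit is INVALID else then_bit
    else_bit = 0 if else_bit is INVALID else else_bit
    return then_bit if c else else_bit


def subword_2(c, when_aware):
    # when c:  x[0] is invalid, x[1] <= 1
    # else:    x[0] <= 1,       x[1] is invalid
    b0 = merge_bit(c, INVALID, 1, when_aware)
    b1 = merge_bit(c, 1, INVALID, when_aware)
    return (b1 << 1) | b0


for c in (0, 1):
    print(c, subword_2(c, when_aware=True), subword_2(c, when_aware=False))
```

With `when_aware=True` the model returns 3 for either value of c, which corresponds to x <= UInt<2>(3). With `when_aware=False` it returns 2 when c is 1 and 1 when c is 0, matching the c ? 2'h2 : 2'h1 output shown above.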

Not being consistent here means that designs that use a vector of bools will have different behavior from a uint with bit selects. These should be equivalent.

I also believe that the SFC interpretation @ekiwi wrote is now wrong in the same way.

seldridge avatar Sep 02 '22 09:09 seldridge

Sounds good, I have added the optimization and a test for it.

zyedidia avatar Sep 02 '22 16:09 zyedidia

Status?

darthscsi avatar Apr 28 '23 18:04 darthscsi

I think if there is interest in having this feature, chipsalliance/firrtl-spec#26 needs to be merged first. I am not in a position to decide what (if anything) needs to change there though, or to perform the actual merge (I can rebase the update though if desired). If that PR is merged (or confirmed that it will be merged), then I think this can be rebased (it seems like ExpandWhens has been refactored a bit since this was opened, but hopefully nothing too drastic). I'm not very familiar with the new reference types (seems to be a source of some merge conflicts), so I'm not sure if the interaction with subword assignment would cause problems in the current proposal/implementation.

If there isn't consensus in favor of this feature/approach, or if significant changes are needed (warranting a redesign of the proposal/implementation), then we can close this PR and the spec PR and just refer back to them if a future version of subword assignment is proposed (or possibly re-use some code if applicable).

zyedidia avatar Apr 28 '23 19:04 zyedidia