Sudden explosion in inlining depth needed
Summary: I have some CLaSH code where, if I start using a seemingly innocent, non-recursive utility function, the inlining limit required to synthesize it suddenly goes through the roof.
I have made a simplified, cut-down version of my code for this ticket. First, here is the code that I want to use:
data Failure
    = Underrun
    | Overrun
    deriving Show

data Buffer n dat = Buffer
    { bufferContents :: Vec n dat
    , bufferLast :: Maybe (Index n)
    }
    deriving (Show, Generic, Undefined)

instance (KnownNat n, Default dat) => Default (Buffer n dat) where
    def = Buffer (pure def) Nothing

remember :: (KnownNat n) => Buffer n dat -> dat -> Buffer n dat
remember Buffer{..} x = Buffer
    { bufferContents = replace bufferLast' x bufferContents
    , bufferLast = Just bufferLast'
    }
  where
    bufferLast' = maybe minBound (+ 1) bufferLast
newtype FetchM n dat m a = FetchM{ unFetchM :: ReaderT (Buffer n dat) (StateT (Maybe (Index n)) (ExceptT Failure m)) a }
    deriving newtype (Functor, Applicative, Monad)

runFetchM :: (Monad m, KnownNat n) => Buffer n dat -> FetchM n dat m a -> m (Either Failure a)
runFetchM buf act = runExceptT $ evalStateT (runReaderT (unFetchM act) buf) Nothing

fetch :: (Monad m, KnownNat n) => FetchM n dat m dat
fetch = do
    Buffer{..} <- FetchM ask
    case bufferLast of
        Nothing -> underrun
        Just bufferLast -> do
            idx <- FetchM get
            when (maybe False (== maxBound) idx) overrun
            when (maybe False (>= bufferLast) idx) underrun
            let idx' = maybe minBound (+ 1) idx
            FetchM $ put $ Just idx'
            return $ bufferContents !! idx'
  where
    overrun = FetchM . lift . lift . throwE $ Overrun
    underrun = FetchM . lift . lift . throwE $ Underrun
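A plain-Haskell model can make the semantics of fetch easier to check outside of Clash. This is a sketch under stated assumptions, not the code above: it uses the transformers package, a list of received elements in place of Vec n dat plus bufferLast, and an Int index in place of Index n. Buf, Fetch, runFetch and failWith are hypothetical names. It keeps the same transformer layering and performs the same checks in the same order.

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
import Control.Monad.Trans.Reader (ReaderT, ask, runReaderT)
import Control.Monad.Trans.State.Strict (StateT, evalStateT, get, put)
import Data.Functor.Identity (Identity, runIdentity)

data Failure = Underrun | Overrun
  deriving (Show, Eq)

-- Hypothetical stand-in for Buffer n dat: a fixed capacity (the n) and the
-- elements received so far (bufferContents up to and including bufferLast).
data Buf a = Buf { capacity :: Int, filled :: [a] }

-- Same layering as FetchM: read-only buffer, last-read index, failure.
type Fetch a = ReaderT (Buf a) (StateT (Maybe Int) (ExceptT Failure Identity))

runFetch :: Buf a -> Fetch a r -> Either Failure r
runFetch buf act = runIdentity (runExceptT (evalStateT (runReaderT act buf) Nothing))

failWith :: Failure -> Fetch a r
failWith = lift . lift . throwE

fetch :: Fetch a a
fetch = do
  buf <- ask
  let xs = filled buf
  if null xs
    then failWith Underrun -- corresponds to bufferLast == Nothing
    else do
      idx <- lift get
      -- The last slot of the buffer's capacity has already been read.
      when (idx == Just (capacity buf - 1)) (failWith Overrun)
      -- The last element received so far has already been read.
      when (maybe False (>= length xs - 1) idx) (failWith Underrun)
      let idx' = maybe 0 (+ 1) idx
      lift (put (Just idx'))
      return (xs !! idx')
```

For example, with only one byte buffered, fetching twice fails with Left Underrun, which is the situation the "bad" cpu handles by waiting for the next byte.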
The "good" version of my program, which doesn't use the above FetchM monad and can be synthesized with a clash-inline-limit of 50 (no external dependencies):
{-# LANGUAGE RecordWildCards, TupleSections #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE GeneralizedNewtypeDeriving, DerivingStrategies #-}
module SpaceInvaders where
import Clash.Prelude hiding (lift, clkPeriod)
import Control.Arrow (first)
import Control.Monad.Reader
import Control.Monad.State hiding (state)
import Control.Monad.Writer as W
import Control.Monad.RWS
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Data.Monoid
newtype CPU i s o a = CPU{ unCPU :: ExceptT () (RWS i (Endo o) s) a }
    deriving newtype (Functor, Applicative, Monad, MonadState s)

input :: CPU i s o i
input = CPU ask

abort :: CPU i s o a
abort = CPU $ throwE ()

runCPU :: (s -> o) -> CPU i s o () -> (i -> State s o)
runCPU mkDef cpu inp = do
    s <- get
    let (s', f) = execRWS (runExceptT $ unCPU cpu) inp s
    put s'
    def <- gets mkDef
    return $ appEndo f def
type Value = Unsigned 8
type Addr = Unsigned 16

data Instr
    = MOV Value
    | HLT
    | NOP
    deriving (Eq, Ord, Show, Generic, Undefined)

fetchInstr :: (Monad m) => m (Unsigned 8) -> m Instr
fetchInstr fetch = do
    b1 <- fetch
    let b1'@(_ :> _ :> d2@r :> d1@p :> d0 :> s2 :> s1 :> s0 :> Nil) = bitCoerce b1 :: Vec 8 Bit
    case b1' of
        0 :> 0 :> _ :> _ :> _ :> 1 :> 1 :> 0 :> Nil -> MOV <$> fetch
        0 :> 1 :> 1 :> 1 :> 0 :> 1 :> 1 :> 0 :> Nil -> return HLT
        _ -> return NOP
data Phase
    = Init
    | Fetching (Buffer 3 Value)
    deriving (Show, Generic, Undefined)

data CPUIn = CPUIn
    { cpuInMem :: Value
    }
    deriving (Show)

data CPUState = CPUState
    { phase :: Phase
    , pc :: Addr
    }
    deriving (Show, Generic, Undefined)

initState :: CPUState
initState = CPUState
    { phase = Init
    , pc = 0x0000
    }

data CPUOut = CPUOut
    { cpuOutMemAddr :: Addr
    , cpuOutMemWrite :: Maybe Value
    }
    deriving (Show)

defaultOut :: CPUState -> CPUOut
defaultOut CPUState{..} = CPUOut{..}
  where
    cpuOutMemAddr = pc
    cpuOutMemWrite = Nothing
cpu :: CPU CPUIn CPUState CPUOut ()
cpu = do
    CPUIn{..} <- input
    CPUState{..} <- get
    case phase of
        Init -> goto $ Fetching def
        Fetching buf -> do
            buf' <- remember buf <$> do
                setPC $ pc + 1
                return cpuInMem
            -- instr_ <- runFetchM buf' $ fetchInstr fetch
            -- instr <- case instr_ of
            --     Left Underrun -> goto (Fetching buf') >> abort
            --     Left Overrun -> error "Overrun"
            --     Right instr -> return instr
            instr <- pure NOP
            goto $ Fetching def
            exec instr
  where
    exec NOP = return ()
    exec instr = return () -- errorX $ show instr

    goto ph = modify $ \s -> s{ phase = ph }
    setPC pc = modify $ \s -> s{ pc = pc }
{-# NOINLINE topEntity #-}
{-# ANN topEntity
    { t_name = "SpaceInvaders"
    , t_inputs =
        [ PortName "CLK_25MHZ"
        , PortName "RESET"
        ]
    , t_output = PortName "VIDEO"
    }) #-}
topEntity
    :: Clock System Source
    -> Reset System Asynchronous
    -> Signal System Value
topEntity = exposeClockReset board
  where
    board = blockRam (pure 0x00 :: Vec VidSize Value) pixAddr mainBoard
    pixAddr = register (minBound :: Index VidSize) $ mux (pixAddr .==. pure maxBound) (pure minBound) (pixAddr + 1)

mainBoard
    :: (HiddenClockReset domain gated synchronous)
    => Signal domain (Maybe (Index VidSize, Value))
mainBoard = register Nothing $ fmap (first fromIntegral) <$> vidWrite
  where
    cpuOut = mealyState (runCPU defaultOut cpu) initState cpuIn
    memAddr = cpuOutMemAddr <$> cpuOut
    memWrite = packWrite memAddr $ cpuOutMemWrite <$> cpuOut

    vidWrite = do
        w <- memWrite
        pure $ case w of
            Just (a, d) | 0x2400 <= a && a < 0x4000 -> Just (truncateB @_ @13 (a - 0x2400), d)
            _ -> Nothing

    progROM addr = unpack <$> romFilePow2 "image.hex" (truncateB @_ @13 <$> addr)

    cpuIn = do
        cpuInMem <- progROM memAddr
        pure CPUIn{..}

packWrite :: (Applicative f) => f a -> f (Maybe b) -> f (Maybe (a, b))
packWrite addr x = sequenceA <$> ((,) <$> addr <*> x)

type VidX = 256
type VidY = 224
type VidSize = VidX * VidY `Div` 8

mealyState :: (HiddenClockReset domain gated synchronous, Undefined s)
           => (i -> State s o) -> s -> (Signal domain i -> Signal domain o)
mealyState f s0 x = mealy step s0 x
  where
    step s x = let (y, s') = runState (f x) s in (s', y)
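As a sanity check on the write path, here is a plain-Haskell sketch of packWrite and the address window used by vidWrite. It assumes ZipList as a stand-in for Signal (both combine values pointwise, cycle by cycle) and Word16, Word8 and Int as stand-ins for Addr, Value and Index VidSize; toVideo and vidSize are hypothetical names. It also checks the arithmetic behind the window: 256 * 224 / 8 = 7168 bytes, 0x2400 + 7168 = 0x4000, and 7168 fits in the 13 bits kept by truncateB.

```haskell
import Control.Applicative (ZipList (..))
import Data.Word (Word16, Word8)

-- Unchanged from the listing: pair each address with an optional write.
packWrite :: (Applicative f) => f a -> f (Maybe b) -> f (Maybe (a, b))
packWrite addr x = sequenceA <$> ((,) <$> addr <*> x)

-- 256 x 224 pixels, eight pixels per byte: 7168 bytes.
vidSize :: Int
vidSize = 256 * 224 `div` 8

-- Hypothetical pure version of the case expression in vidWrite:
-- keep writes that land in [0x2400, 0x4000) and rebase them to 0.
toVideo :: Maybe (Word16, Word8) -> Maybe (Int, Word8)
toVideo (Just (a, d))
  | 0x2400 <= a && a < 0x4000 = Just (fromIntegral (a - 0x2400), d)
toVideo _ = Nothing
```

Because sequenceA on a pair (a, Maybe b) yields Maybe (a, b), a cycle without a write simply produces Nothing, regardless of the address.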
And then the "bad" version is the same, except that the usage of FetchM is enabled in cpu:
cpu :: CPU CPUIn CPUState CPUOut ()
cpu = do
    CPUIn{..} <- input
    CPUState{..} <- get
    case phase of
        Init -> goto $ Fetching def
        Fetching buf -> do
            buf' <- remember buf <$> do
                setPC $ pc + 1
                return cpuInMem
            instr_ <- runFetchM buf' $ fetchInstr fetch
            instr <- case instr_ of
                Left Underrun -> goto (Fetching buf') >> abort
                Left Overrun -> error "Overrun"
                Right instr -> return instr
            goto $ Fetching def
            exec instr
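The retry protocol in the "bad" cpu (buffer one byte per cycle, try to decode, and on Underrun stay in Fetching buf' until the next byte arrives) can be modelled in plain Haskell. This sketch assumes only base and transformers and reuses the shape of fetchInstr: the decoder is polymorphic in the monad and receives fetch as an argument. Here fetch is StateT uncons over the buffered bytes, so running out of bytes plays the role of Underrun. The masks 0xC7/0x06 and 0x76 encode the two Vec 8 Bit patterns, since bitCoerce puts the most significant bit first. Overrun, the 3-byte capacity and the initial Init cycle are not modelled, and decodeWith, tryDecode, cycleStep and runCycles are hypothetical names.

```haskell
import Control.Monad.Trans.State.Strict (StateT (..), evalStateT)
import Data.Bits ((.&.))
import Data.List (mapAccumL, uncons)
import Data.Word (Word8)

data Instr = MOV Word8 | HLT | NOP
  deriving (Show, Eq)

-- Same shape as fetchInstr: 0b00xxx110 needs a second byte, 0b01110110 is HLT.
decodeWith :: (Monad m) => m Word8 -> m Instr
decodeWith fetch = do
  b <- fetch
  if b .&. 0xC7 == 0x06
    then MOV <$> fetch
    else pure (if b == 0x76 then HLT else NOP)

-- Nothing means the buffered bytes ran out (an Underrun).
tryDecode :: [Word8] -> Maybe Instr
tryDecode = evalStateT (decodeWith (StateT uncons))

-- One clock cycle: remember the incoming byte, then try to decode the buffer.
cycleStep :: [Word8] -> Word8 -> ([Word8], Maybe Instr)
cycleStep buf b = case tryDecode buf' of
    Nothing -> (buf', Nothing)     -- Underrun: keep buf', retry next cycle
    Just instr -> ([], Just instr) -- decoded: back to an empty buffer
  where
    buf' = buf ++ [b]

runCycles :: [Word8] -> [Maybe Instr]
runCycles = snd . mapAccumL cycleStep []
```

For instance, runCycles [0x06, 0x2A, 0x76, 0x00] needs two cycles for the two-byte MOV and one cycle each for HLT and NOP.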
So the problem with the "bad" version is that, even in this small repro case, it takes an inlining depth between 200 and 250 (200 is not enough, 250 is) to properly eliminate intermediate values of function type. And in the real CLaSH code that I intend to use this in, I can't push the inlining depth high enough without the synthesizer running out of memory on my 24 GB notebook.
This behaviour is triggered by a data type that has a field with a function type, in this case the Endo in the CPU data type. Since functions cannot be (trivially) represented by a finite number of bits, Clash will inline all values of a data type that has function (and other "non-representable") types as fields.
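To see where the function-typed value comes from, here is a minimal plain-Haskell sketch of CPU and runCPU. It assumes transformers instead of mtl (so RWS operations are lifted explicitly) and returns a pair instead of State s o; Out, emit and step are hypothetical. The writer accumulates an Endo o, that is, a function o -> o, which is only applied to the default output at the very end. Because ExceptT is the outermost layer, state changes made before abort are kept, which is what goto (Fetching buf') >> abort in the "bad" cpu relies on.

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
import Control.Monad.Trans.RWS (RWS, ask, execRWS, modify, tell)
import Data.Monoid (Endo (..))

-- Same stack as the CPU newtype, as a plain type synonym.
type CPU i s o = ExceptT () (RWS i (Endo o) s)

input :: CPU i s o i
input = lift ask

abort :: CPU i s o a
abort = throwE ()

-- Hypothetical helper: record an update to the default output.
emit :: (o -> o) -> CPU i s o ()
emit f = lift (tell (Endo f))

-- Run one step: the default output is computed from the new state,
-- then the accumulated Endo is applied to it.
runCPU :: (s -> o) -> CPU i s o () -> i -> s -> (o, s)
runCPU mkDef cpu inp s =
  let (s', f) = execRWS (runExceptT cpu) inp s
  in (appEndo f (mkDef s'), s')

data Out = Out { outAddr :: Int, outWrite :: Maybe Int }
  deriving (Show, Eq)

-- Hypothetical step: bump the "PC", abort on negative input,
-- otherwise request a write of the input value.
step :: CPU Int Int Out ()
step = do
  x <- input
  lift (modify (+ 1))
  when (x < 0) abort
  emit (\o -> o { outWrite = Just x })
```

With input -3, the step aborts: the PC increment survives, but no write is requested.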
There's no (quick) fix to handle this in a non-exponential manner. I'm researching some new compilation approaches that are (hopefully) significantly faster and consume less memory. Until then, the only workaround is to avoid data types with fields that have a function type.
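One way to follow that workaround (a sketch, not necessarily what the author ended up doing) is to replace the Endo CPUOut writer with a first-order monoid recording which output fields were set. OutUpdate, writeMem, applyUpdate and viaEndo are hypothetical names, and Word16/Word8 stand in for Addr/Value. Data.Monoid's First is chosen because Endo f <> Endo g = Endo (f . g), so with Endo the update told first is applied last and wins; First keeps that behaviour. The record has no function-typed fields, so it should not hit the inlining behaviour described above.

```haskell
import Data.Maybe (fromMaybe)
import Data.Monoid (Endo (..), First (..))
import Data.Word (Word16, Word8)

data CPUOut = CPUOut
  { cpuOutMemAddr :: Word16
  , cpuOutMemWrite :: Maybe Word8
  }
  deriving (Show, Eq)

-- First-order description of which fields to override, instead of Endo CPUOut.
data OutUpdate = OutUpdate
  { setMemAddr :: First Word16
  , setMemWrite :: First (Maybe Word8)
  }

instance Semigroup OutUpdate where
  OutUpdate a1 w1 <> OutUpdate a2 w2 = OutUpdate (a1 <> a2) (w1 <> w2)

instance Monoid OutUpdate where
  mempty = OutUpdate mempty mempty

-- Hypothetical helper: drive a memory write.
writeMem :: Word16 -> Word8 -> OutUpdate
writeMem a d = OutUpdate (First (Just a)) (First (Just (Just d)))

-- Interpret the update against the default output computed from the state.
applyUpdate :: OutUpdate -> CPUOut -> CPUOut
applyUpdate (OutUpdate a w) o = CPUOut
  { cpuOutMemAddr = fromMaybe (cpuOutMemAddr o) (getFirst a)
  , cpuOutMemWrite = fromMaybe (cpuOutMemWrite o) (getFirst w)
  }

-- For comparison: the same updates composed as Endo, as the CPU writer does now.
viaEndo :: [OutUpdate] -> CPUOut -> CPUOut
viaEndo us = appEndo (foldMap (Endo . applyUpdate) us)
```

Since applyUpdate (u1 <> u2) behaves like applyUpdate u1 . applyUpdate u2, swapping the writer type should not change the CPU's observable outputs.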
Thanks for looking into this; it was really blocking me. I will look into alternative implementations for CPUOut.
Two comments though:
- Endo CPUOut here is basically only used for control flow. Such a value never escapes the pure function inside the mealy call. Can we hope to one day use this property to avoid the present problem?
- I wonder if GRIN-like full-program defunctionalization would provide an easy path to function-valued signals, opening the door to completely new approaches to compilation.
I believe we looked into this before and concluded that, while it would work, the resulting HDL would look nothing like the source code, thus frustrating debugging attempts.
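For readers unfamiliar with the term: defunctionalization replaces each function value with a constructor of an ordinary data type that carries the lambda's free variables, plus a first-order apply function that interprets it. Applied by hand to the two state updates used in cpu (goto and setPC), a minimal sketch could look like the following; Fn, applyFn and the simplified S record are hypothetical. Note that representing composition requires a recursive Compose constructor.

```haskell
-- Simplified stand-in for CPUState.
data S = S { phase :: Int, pc :: Int }
  deriving (Show, Eq)

-- Higher-order originals, as in cpu's where clause.
goto :: Int -> S -> S
goto ph s = s { phase = ph }

setPC :: Int -> S -> S
setPC p s = s { pc = p }

-- Defunctionalized: one constructor per lambda, fields = its free variables.
data Fn
  = Goto Int
  | SetPC Int
  | Compose Fn Fn -- stands for (.)
  | Id
  deriving (Show, Eq)

applyFn :: Fn -> S -> S
applyFn (Goto ph) s = s { phase = ph }
applyFn (SetPC p) s = s { pc = p }
applyFn (Compose f g) s = applyFn f (applyFn g s)
applyFn Id s = s
```

applyFn (Compose (Goto 1) (SetPC 5)) gives the same result as goto 1 . setPC 5, but the update is now an inspectable first-order value.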