Soss.jl

Make `iid` take a shape instead of a length

Open · cscherrer opened this issue 4 years ago · 1 comment

@vasishth asked on Twitter about implementing models like these. The "varying intercepts, varying slopes" model in Listing 6 would look something like this:

# varying intercepts, varying slopes
#
# Proposed syntax: `iid` takes a shape such as `(2, J)` rather than only a length.
# Row 1 of `u`/`w` holds the intercept adjustments and row 2 the slope adjustments,
# with one column per subject (`u`) or per item (`w`).
vivs = @model N,so,subj,item  begin
    J = maximum(subj)   # number of subjects (assumes subj is coded 1:J)
    K = maximum(item)   # number of items (assumes item is coded 1:K)
    σu ~ HalfCauchy() |> iid(2)
    σw ~ HalfCauchy() |> iid(2)
    σe ~ HalfCauchy()
    β ~ Cauchy() |> iid(2)
    # `Normal(0, σu)` with a 2-vector scale is not a single distribution, so draw
    # standard normals with shape (2, J) / (2, K) and scale row r by σu[r] / σw[r].
    zu ~ Normal() |> iid(2,J)
    zw ~ Normal() |> iid(2,K)
    u = σu .* zu
    w = σw .* zw
    rt ~ For(1:N) do i
        # intercept terms use row 1, slope terms (multiplying so[i]) use row 2,
        # always indexed by the subject/item of observation i
        μ = (β[1]
            + u[1,subj[i]]
            + w[1,item[i]]
            + so[i] * (β[2]
                    + u[2,subj[i]]
                    + w[2,item[i]]
                    )
            )
        LogNormal(μ,σe)
    end
end

Currently this doesn't work because `iid` only accepts a length. Let's change this by removing the `Int` shape methods and always representing the shape as a `Tuple`. For example, `x ~ dist |> iid(3)` should produce `iid((3,), dist)`.

Down the road, we'll also have `iid` methods that produce an iterator. To accommodate those, we can either make the first argument a `Union` type or leave it unannotated.

— cscherrer, Sep 27 '19 15:09

Got this working with `For`; still need to check `iid`:

julia> using Soss

julia> vivs = @model so,subj,item  begin
           # Varying intercepts, varying slopes: row 1 of u/w holds the intercept
           # adjustments, row 2 the slope adjustments.
           N = length(so)      # number of observations
           J = maximum(subj)   # number of subjects (assumes subj is coded 1:J)
           K = maximum(item)   # number of items (assumes item is coded 1:K)
           σu ~ HalfCauchy() |> iid(2)
           σw ~ HalfCauchy() |> iid(2)
           σe ~ HalfCauchy()
           u ~ For(1:2, 1:J) do i,j
                   Normal(0.0, σu[i])
               end
           # w is indexed by item[i], so it needs one column per item (1:K, not 1:J)
           w ~ For(1:2, 1:K) do i,j
                   Normal(0.0, σw[i])
               end
           β ~ Cauchy() |> iid(2)
           rt ~ For(1:N) do i
               μ = (β[1]
                   + u[1,subj[i]]
                   + w[1,item[i]]
                   + so[i] * (β[2]
                           + u[2,subj[i]]
                           + w[2,item[i]]
                           )
                   )
               LogNormal(μ,σe)
           end
       end;

julia> so = rand(16);  # 16 observations; the model derives N = length(so) itself

julia> subj = repeat(1:4, inner=4);  # subjects 1..4, each repeated 4 times

julia> item = repeat(1:4, outer=4);  # items 1..4, cycled 4 times

julia> truth = rand(vivs(so=so,subj=subj, item=item))  # joint draw of all model variables, including rt
(so = [0.05705834099677576, 0.8302100111891892, 0.9815461581204299, 0.4905848664186814, 0.7840078908223014, 0.26409912550822834, 0.6118625189292837, 0.9810481943748384, 0.7911981601876461, 0.6880797669217511, 0.09416783465794443, 0.49754503634465475, 0.18718713168785528, 0.40802127078364125, 0.1745261279471364, 0.7703638255385965], subj = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], item = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], N = 16, K = 4, J = 4, σw = [0.35949337840736484, 10.82124873510444], σu = [8.519789607704864, 1.2442508432497936], σe = 5.510426090810765, β = [7.495790844246555, 0.11569662766768342], w = [0.3817578211856579 0.15484291459555616 0.07545678605921653 -0.15263552378673742; -10.478401991528616 22.1154149324285 0.08079165426702911 -20.96468984392859], u = [-3.5976262589868964 5.896856499624867 11.111045257755764 -3.322043209333268; 0.3194588653394458 2.0253587470686223 1.631192007751775 -1.1709162421913692], rt = [1.4676131900996339, 535125.1188398163, 0.3664085021422927, 0.013748011025174207, 1089.2523011737853, 1.9403836678389742e6, 1.3824131315950798e6, 5.280738317890048, 176.3715912784793, 7.908130260925237e19, 9.155296087038978e8, 5969.7766219371915, 0.005648596634112367, 0.9941209093280838, 228.35328206698105, 7.986975771856124e-9])

julia> dynamicHMC(vivs(so=so,subj=subj, item=item), (rt = truth.rt,))  # condition on the simulated rt; one NamedTuple per draw
1000-element Array{NamedTuple{(:σw, :σu, :σe, :β, :w, :u),Tuple{Array{Float64,1},Array{Float64,1},Float64,Array{Float64,1},Array{Float64,2},Array{Float64,2}}},1}:
 (σw = [0.5798238385221418, 13.503325362861986], σu = [8.268653576728408, 0.8913758048343248], σe = 5.85220048952389, β = [1.0468628358017023, 0.16123669660091525], w = [-0.46411955039390695 -0.17834756654093148 -0.3667869947411047 -0.6267518992642751; -0.7705948803399764 26.479447421878938 6.2706746452139255 -14.042468047180545], u = [-5.9292490023414715 4.4628749032304205 9.665143808172513 -7.755218059048623; -0.1658039059901143 -0.7200216460709323 0.4935542671519438 -0.3873579875303357])
 (σw = [0.8139014158361357, 7.923530933536202], σu = [9.97402194246488, 0.9392401721164817], σe = 6.106099463934436, β = [0.21018297271210148, -0.16176741124245692], w = [-0.6724055581654802 0.6197657855711941 -0.5671065223537874 0.8934343622376998; -1.5394750367222985 14.397463687393508 7.3640671259709185 -7.766451373435515], u = [-2.1301204140174947 6.518966952869998 16.123606725116574 -5.280820333348136; -0.9109214578484315 -0.2562944859301565 0.9031482120668543 -0.2749656167262518])    
 (σw = [0.7775128331960959, 7.777069517187342], σu = [11.235520990712503, 0.8891987693255932], σe = 8.272417661882026, β = [-0.0760614100199653, -0.3295776452726164], w = [0.6449146701318197 -0.006190933014135869 0.3669104515712563 -0.49659692412650686; 0.3669523389238277 19.348702947293653 6.978206960600503 -3.723812002469496], u = [-1.82685237453415 9.744428916689282 15.830505567609313 -6.0239521367658675; 1.019173849524488 0.8129439240218211 -0.6770602798944121 0.23853629453182734])     
 (σw = [1.2875228249050883, 5.623689454618413], σu = [11.427184125801945, 0.41226316817871705], σe = 6.8790783687886385, β = [-1.3070981867624947, 0.18440069873789383], w = [-1.1464416134534534 0.4241705901722263 1.4865114006915325 -1.6607033144321552; -0.31050079699833893 7.12954988151638 1.9517033467608593 -11.513073204791846], u = [6.00839185835003 10.860757621355841 16.826140738285552 2.4049779492060845; 0.8296207391955714 0.6711002505317943 0.25139968592117967 0.20594169557402234])    
 (σw = [0.8883264435481624, 8.862508021139801], σu = [14.589147477657018, 0.6285659460121209], σe = 10.888932716552286, β = [-0.17628258690148774, 0.9151788775953427], w = [0.34958121782627233 0.32727503361438176 -1.3386168167547396 -0.15288780498235027; 2.3885391106984755 11.81295700376109 3.077653926696049 -6.043868715154859], u = [7.805816773296412 11.304324183594536 18.626916108193594 0.4015581528327861; -0.44921015490790345 -0.5719526770875699 0.2161170547903416 -0.39370648531348523]) 
 (σw = [0.5104617164729187, 3.6726614781238967], σu = [7.660040212814474, 0.2075831657020835], σe = 9.154077230829493, β = [-1.4616025317479766, -1.9746061973928304], w = [-0.23972986891949138 0.5570632748127919 1.323458686743813 0.013252790642741652; -1.8023620333804047 6.146671583679789 -0.3685326884074048 -7.302326268444786], u = [8.047598619892675 10.724623342728089 17.43706426103348 0.8591319128615474; 0.13341950216333578 0.477108065943323 -0.0890112554589888 0.05325231278074517])     
 (σw = [0.5104617164729187, 3.6726614781238967], σu = [7.660040212814474, 0.2075831657020835], σe = 9.154077230829493, β = [-1.4616025317479766, -1.9746061973928304], w = [-0.23972986891949138 0.5570632748127919 1.323458686743813 0.013252790642741652; -1.8023620333804047 6.146671583679789 -0.3685326884074048 -7.302326268444786], u = [8.047598619892675 10.724623342728089 17.43706426103348 0.8591319128615474; 0.13341950216333578 0.477108065943323 -0.0890112554589888 0.05325231278074517])     
 (σw = [0.5104617164729187, 3.6726614781238967], σu = [7.660040212814474, 0.2075831657020835], σe = 9.154077230829493, β = [-1.4616025317479766, -1.9746061973928304], w = [-0.23972986891949138 0.5570632748127919 1.323458686743813 0.013252790642741652; -1.8023620333804047 6.146671583679789 -0.3685326884074048 -7.302326268444786], u = [8.047598619892675 10.724623342728089 17.43706426103348 0.8591319128615474; 0.13341950216333578 0.477108065943323 -0.0890112554589888 0.05325231278074517])     
 (σw = [0.5104617164729187, 3.6726614781238967], σu = [7.660040212814474, 0.2075831657020835], σe = 9.154077230829493, β = [-1.4616025317479766, -1.9746061973928304], w = [-0.23972986891949138 0.5570632748127919 1.323458686743813 0.013252790642741652; -1.8023620333804047 6.146671583679789 -0.3685326884074048 -7.302326268444786], u = [8.047598619892675 10.724623342728089 17.43706426103348 0.8591319128615474; 0.13341950216333578 0.477108065943323 -0.0890112554589888 0.05325231278074517])     
 (σw = [0.5104617164729187, 3.6726614781238967], σu = [7.660040212814474, 0.2075831657020835], σe = 9.154077230829493, β = [-1.4616025317479766, -1.9746061973928304], w = [-0.23972986891949138 0.5570632748127919 1.323458686743813 0.013252790642741652; -1.8023620333804047 6.146671583679789 -0.3685326884074048 -7.302326268444786], u = [8.047598619892675 10.724623342728089 17.43706426103348 0.8591319128615474; 0.13341950216333578 0.477108065943323 -0.0890112554589888 0.05325231278074517])     
 (σw = [0.5104617164729187, 3.6726614781238967], σu = [7.660040212814474, 0.2075831657020835], σe = 9.154077230829493, β = [-1.4616025317479766, -1.9746061973928304], w = [-0.23972986891949138 0.5570632748127919 1.323458686743813 0.013252790642741652; -1.8023620333804047 6.146671583679789 -0.3685326884074048 -7.302326268444786], u = [8.047598619892675 10.724623342728089 17.43706426103348 0.8591319128615474; 0.13341950216333578 0.477108065943323 -0.0890112554589888 0.05325231278074517])     
 (σw = [0.5104617164729187, 3.6726614781238967], σu = [7.660040212814474, 0.2075831657020835], σe = 9.154077230829493, β = [-1.4616025317479766, -1.9746061973928304], w = [-0.23972986891949138 0.5570632748127919 1.323458686743813 0.013252790642741652; -1.8023620333804047 6.146671583679789 -0.3685326884074048 -7.302326268444786], u = [8.047598619892675 10.724623342728089 17.43706426103348 0.8591319128615474; 0.13341950216333578 0.477108065943323 -0.0890112554589888 0.05325231278074517])     
 (σw = [0.5027577170428184, 4.602768360175987], σu = [7.8451203096252256, 0.20592410867008884], σe = 8.865535366668338, β = [-1.4492861259232834, -1.8890810001493468], w = [-0.23028299459407264 0.4186209098581454 1.0055905441844013 0.10138978258876338; -1.6710794074787807 6.826277791558207 -0.1351299408331651 -7.405832909684176], u = [7.799204049654453 10.585646423723496 17.182947336070768 0.918742980230471; -0.07173850401523499 0.0577729569340773 -0.07302871209541824 0.1043416860494665])  
 (σw = [0.5027577170428184, 4.602768360175987], σu = [7.8451203096252256, 0.20592410867008884], σe = 8.865535366668338, β = [-1.4492861259232834, -1.8890810001493468], w = [-0.23028299459407264 0.4186209098581454 1.0055905441844013 0.10138978258876338; -1.6710794074787807 6.826277791558207 -0.1351299408331651 -7.405832909684176], u = [7.799204049654453 10.585646423723496 17.182947336070768 0.918742980230471; -0.07173850401523499 0.0577729569340773 -0.07302871209541824 0.1043416860494665])  
 (σw = [0.5027577170428184, 4.602768360175987], σu = [7.8451203096252256, 0.20592410867008884], σe = 8.865535366668338, β = [-1.4492861259232834, -1.8890810001493468], w = [-0.23028299459407264 0.4186209098581454 1.0055905441844013 0.10138978258876338; -1.6710794074787807 6.826277791558207 -0.1351299408331651 -7.405832909684176], u = [7.799204049654453 10.585646423723496 17.182947336070768 0.918742980230471; -0.07173850401523499 0.0577729569340773 -0.07302871209541824 0.1043416860494665])  
 (σw = [0.5027577170428184, 4.602768360175987], σu = [7.8451203096252256, 0.20592410867008884], σe = 8.865535366668338, β = [-1.4492861259232834, -1.8890810001493468], w = [-0.23028299459407264 0.4186209098581454 1.0055905441844013 0.10138978258876338; -1.6710794074787807 6.826277791558207 -0.1351299408331651 -7.405832909684176], u = [7.799204049654453 10.585646423723496 17.182947336070768 0.918742980230471; -0.07173850401523499 0.0577729569340773 -0.07302871209541824 0.1043416860494665])  
 ⋮                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             
 (σw = [2.022146554090919, 1.1581526094817245], σu = [6.99827981003005, 0.9185324040424694], σe = 10.61002381313232, β = [3.077826767417711, 0.6956698511778626], w = [1.5606551630753733 2.728184195925836 0.3922524152494362 -2.9494657602990824; 0.446683008660875 1.372557933798523 -1.1466438328286608 1.6434108124743116], u = [0.4972054522836434 1.5084135375053733 15.86086535347678 -8.9055878841549; 0.4866024735340556 0.9668305394898432 -1.0015867054529075 0.2953978762005536])                 
 (σw = [3.3019347205924285, 2.3462779857814695], σu = [9.092289544294083, 0.980758772423812], σe = 6.00411916176837, β = [0.9989913827111327, 1.1761583155039066], w = [-2.4771174681027595 5.841886043751172 0.042475724734855363 -2.3267270279137438; 0.010991375715552365 3.43535888276596 1.4289676988453608 -3.558215975285553], u = [0.7376621988235473 8.225108770937 15.985302540237067 -5.197520398511353; 1.3041045135101097 0.4959365845795223 -0.647475725051934 1.0001447903955831])              
 (σw = [3.6491772287331528, 2.0655590819513137], σu = [6.344543119723851, 1.0501701478136258], σe = 11.13881769718889, β = [1.7061136669088646, -0.3612351207656136], w = [2.2955603107484435 3.8163549103985503 6.010188711105561 -0.12161271159915539; -0.27307509019457826 0.92483677251165 0.27184906149546945 2.706994511542514], u = [-8.868067421450757 -2.978251464305822 6.985575094772051 -4.447027487405508; 1.3156919212606257 -0.07386541038418391 0.2193328592429148 1.1738575588682245])        
 (σw = [3.76350477565868, 1.5621491557758536], σu = [5.02810028817714, 0.8115424277940386], σe = 7.648125398789507, β = [1.1643029767309596, -0.24175396505808733], w = [3.7181690067601294 5.530670946730926 8.537899438275094 -1.307017938982401; -0.02684355591277817 3.147235489350998 1.1778512209145007 2.3844024613776043], u = [-7.895494443989046 -3.2612866614951295 7.3948505680477865 -6.527769535705547; 0.14183865335050233 0.035225000110169116 0.32149829915968525 -1.391441621763514])        
 (σw = [7.428133873319701, 1.9095101318661014], σu = [5.0074124228068335, 0.5306364843099903], σe = 15.241135423200548, β = [-0.23734150688497735, 1.9430833280249726], w = [2.599604377429512 8.5172121314765 6.489210784614295 3.310538959797749; 1.267377171410223 2.552242635794601 -0.6502570749617892 -2.2190505185847513], u = [-4.203351871031629 -4.779395994028225 5.5896684869822435 -5.602353627114084; 0.2801248048206726 -0.19827603335838218 0.36091943316882225 1.0188467384625859])           
 (σw = [5.563819079392282, 1.1001558443488748], σu = [10.133317498436305, 0.4512142816360848], σe = 6.879435972703339, β = [0.44789225374125063, 2.657239340374926], w = [2.6440726245343273 12.81341840200512 3.0656590787713083 2.482095809393285; -1.4500858881062302 2.217053576492654 -0.038509822468298205 1.5357544596084218], u = [-8.985391537512912 -3.5562822129347946 9.659969861264946 -3.7958868724105472; -0.26954054575898057 -0.05752040995542995 -0.17378738445632605 0.011288717809376347]) 
 (σw = [7.887988617101772, 1.129022594110785], σu = [6.264394547975936, 0.2799869160595212], σe = 9.937028134023524, β = [0.15572453369925301, 1.3371143459548989], w = [2.6539824713478266 15.288114558269976 3.6218921883667083 -0.9380602917168528; 1.2943807151577886 1.7555866072345783 0.620833644440374 -1.8978593284512222], u = [-5.3229681098951405 -2.9858383923128775 9.80525041010234 -11.044561817022258; 0.006215528384256042 -0.31673139021600716 -0.1762300744152051 0.1521326702157078])     
 (σw = [4.922429713044198, 4.964241263916502], σu = [11.118887512101587, 0.7727014918601854], σe = 6.8480371150291415, β = [1.1008352969756008, 1.5580857971760318], w = [1.5420926951318357 8.905406110812631 5.817055432840201 -1.329229202821511; 3.5188916131521113 -1.1584533369333476 0.8922587671414995 -1.4226551544863906], u = [-2.8502680232882414 -4.388230337028525 17.414813971447806 -8.851411774134375; -0.13357513465824797 0.36715944424928465 0.15803293095132415 0.27627056193438954])     
 (σw = [3.7277737781726894, 5.13379681720589], σu = [8.776694676958426, 0.3237567839752413], σe = 7.2036078681023685, β = [1.5557454538250146, 0.5015095574029671], w = [1.0080619934590536 10.35373573807001 8.216748219500241 -0.49914588307896035; 4.7239029501006105 -4.588430360207598 1.0025513399061834 -2.3466256857667926], u = [-3.680803140989154 -1.879424008392054 12.614477851911273 -9.744805863348422; 0.25375954202687356 0.0030830919004767032 -0.008905229564939027 0.42979625376026864])   
 (σw = [6.185638616913895, 3.208808847553807], σu = [6.912567687726786, 0.34680127190484833], σe = 7.0255474117739904, β = [-0.1691994629826809, 0.4142571321065145], w = [-0.5500522269887227 10.840554236382568 6.97046669505481 -0.30408377929945657; 6.199034505840589 -1.1620184618216236 0.25278438073313875 -2.0517479968578662], u = [-6.664605031970424 0.4438190312190164 15.431291928650325 -11.56405432917367; 0.31562001998118444 -0.024331073821040834 -0.3134651759981899 0.28368188919594023]) 
 (σw = [6.301260277575661, 3.4669059443565033], σu = [6.860761280907575, 0.446387130496555], σe = 8.412736694144868, β = [0.8414427121172452, 0.6440022186830215], w = [-0.8082050297393621 12.638676294181783 7.836598280509722 0.5626935494149305; 5.364926300797164 2.83908752822915 -0.2568339110330531 -1.092376436211453], u = [-7.404575918936336 1.2174211678232663 16.45336990727346 -12.75749663159704; -0.09404986926563613 -0.1755727956930867 -0.014316911222656992 0.4113341582813656])          
 (σw = [11.575877220559505, 4.448783626723499], σu = [10.996344109495915, 0.413378559057261], σe = 8.310522818869401, β = [1.6892627911177125, -0.8020847184288508], w = [4.480939721883063 10.049462132723887 6.545579434429497 -7.777665058202643; -8.532299087913273 0.3812664162711575 -1.7041562313120864 -1.4243678056895936], u = [1.9988125193608337 9.109781682566469 12.186501831923687 -11.252895595771879; 0.18026823412283904 0.46466475961308146 -0.07436967929599547 -0.31218269729033665])     
 (σw = [9.06428592842808, 7.183365775795522], σu = [12.29216833612979, 1.2935935976389876], σe = 4.577975863715498, β = [-0.34502226449480056, 0.6977665439969044], w = [4.22753695397991 10.56604762750238 3.4076300540192204 -4.076914774021226; -8.909784702811175 10.447011910155217 3.4573154361088334 2.13670709004704], u = [-4.403931007210874 3.0543395020373296 18.397433889764354 -6.014999331327232; 0.1508829658069909 -0.3816122076126016 -1.578197751162008 -0.41627711182158544])              
 (σw = [11.286273170948215, 5.166888859560457], σu = [7.354491406801909, 1.1330354141311632], σe = 5.6036727263519275, β = [2.6187998793583342, 0.5883451624489027], w = [-2.36180729284604 4.727182959246244 8.58373222944016 -5.190579949934731; -1.16299227267625 5.916988827683684 -11.373554464328844 -7.487258479656289], u = [1.3652586719003372 11.5212644357808 10.746476401679043 -5.1962399696420185; -0.4545338714427359 0.9392434449717701 -0.38936623008069265 -2.1471024380568005])             
 (σw = [7.089490194283234, 9.138480611242867], σu = [6.833182867848835, 0.7738304059647343], σe = 9.024512233125364, β = [3.0352027491776656, 0.8271067000284571], w = [-2.494209297998554 5.108427345235088 8.199294916936852 -5.338664272231864; -1.267837237003199 1.6743756609013019 -8.971474707569916 -7.803610137801066], u = [1.8637001832250473 8.62739134541818 8.683380555017857 -3.4227065843932687; 0.29977602128286746 0.37354413844645545 0.882645465955023 0.5063873484692758])                
 (σw = [2.5495268916746276, 55.118841133694936], σu = [4.210033824514281, 2.436929125670833], σe = 7.032689351414416, β = [1.6556783067208374, 0.3349987636436081], w = [-3.027758671688709 -4.214215084273659 0.9479705454295031 1.5122223948584794; 1.3235329515345386 38.4527505879667 5.072394895894986 -1.8125723798223707], u = [-3.3848374320634513 4.6254998061233366 1.1620190007182056 -8.375304940402666; -2.215828694874676 -1.1940385916567051 -0.20513370874827477 -1.8131171257355776])         

NOTE: The result of `dynamicHMC` in this case doesn't yet work with particles; that's a separate issue.

— cscherrer, Oct 01 '19 20:10