keras-cn
keras-cn copied to clipboard
keras中Masking+LSTM后面接一个TimeDistributedDense,Dense输出结果没有mask
代码部分:
language_model = Sequential()
language_model.add(Embedding(vocab_size+2, EMBEDDING_DIM, mask_zero=True, weights=[embedding_matrix] ,name='embedding_layer'))
language_model.add(Masking(mask_value=PADDING,name='masking_layer'))
language_model.add(LSTM(output_dim=OTHERS, return_sequences=True,name='gru_layer'))
language_model.add(TimeDistributedDense(OUTPUT_TYPES_NUM,activation='softmax',name='softmax_layer'))
测试用例:
x_test=np.asarray([[0,0,0,2,3,4,5,6],[0,0,0,0,9,8,3,4],[0,0,0,0,0,9,8,3]])
Softmax 的输出结果: Dense Results: (3L, 8L, 10L) [[[ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10468332 0.09065267 0.10465907 0.10914598 0.09866636 0.10089909 0.09287892 0.11226305 0.09915846 0.08699314] [ 0.10090157 0.1001429 0.10266486 0.10592359 0.08957883 0.1033702 0.10338669 0.09631021 0.09575413 0.10196707] [ 0.1009273 0.10819422 0.09471504 0.09420478 0.10287543 0.09454429 0.08861585 0.10622577 0.10038092 0.10931643] [ 0.0961623 0.10527957 0.09671341 0.09006541 0.10609487 0.09137738 0.08720418 0.1061098 0.1045992 0.11639385] [ 0.09365714 0.09961554 0.10184344 0.09357952 0.10447578 0.09096033 0.08860005 0.11356424 0.10304262 0.11066134]]
[[ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.08880188 0.11206884 0.08541028 0.09300748 0.11226043 0.08611012 0.102557 0.11383081 0.10294707 0.10300611] [ 0.10394291 0.09647608 0.10178506 0.10291336 0.10539917 0.0867321 0.09617651 0.10704685 0.09978233 0.09974565] [ 0.0968347 0.10330293 0.10114764 0.09074623 0.09358761 0.09601314 0.10900775 0.10108642 0.10264133 0.10563225] [ 0.09343617 0.10481153 0.09716355 0.09031702 0.10999709 0.09630899 0.09260204 0.10974768 0.1030058 0.10261014]]
[[ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.10432124 0.09956408 0.09805973 0.10185226 0.10514925 0.09005494 0.09369726 0.10397689 0.10004393 0.10328046] [ 0.08880188 0.11206884 0.08541028 0.09300748 0.11226043 0.08611012 0.102557 0.11383081 0.10294707 0.10300611] [ 0.10394291 0.09647608 0.10178506 0.10291336 0.10539917 0.0867321 0.09617651 0.10704685 0.09978233 0.09974565] [ 0.0968347 0.10330293 0.10114764 0.09074623 0.09358761 0.09601314 0.10900775 0.10108642 0.10264133 0.10563225]]]
最终结果: [[4 4 4 7 3 9 9 7] [4 4 4 4 7 7 6 4] [4 4 4 4 4 7 7 6]]
此外,有个比较奇怪的发现,输入必须的Padding必须加在pre位置上,求大大门指导
@jxwb088047 老实讲,对Masking这些东西不太熟,没怎么用过,这里先提醒你一下是否正确使用了Masking:
- 当输入信号在某个时间步上的所有值都等于给定的Masking值时,该信号才会被屏蔽。
从你的测试看来,你的输入信号在任何一个时间步上都不是一个全同信号,因此不管你的值设为多少,都不可能满足上面的条件,因此不会被mask。
@MoyanZitto Sorry,没贴(Softmax层的)上一层的输出,已修改,请继续指正 GRU Results: (3L, 8L, 21L) [[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0.00840368 -0.10805743 -0.08397113 -0.01513356 0.02580945 0.08339522 -0.07102467 -0.00039871 -0.02749132 -0.06340679 -0.04160816 0.02557643 -0.06138108 0.08559444 0.07123405 -0.06140784 -0.00829985 -0.00838125 0.04480657 0.04952285 -0.02105066] [-0.07398035 0.03058206 0.02635643 -0.06123957 -0.00590142 -0.09541193 -0.10703349 -0.02872061 0.13137171 0.08835198 0.07901608 0.12995203 -0.07013361 -0.15949027 -0.08306988 0.01631884 0.07318716 -0.06081156 -0.12927294 0.02352616 0.13821183] [-0.05236683 0.03891416 -0.04697997 -0.04320637 -0.06870133 0.01759486 -0.0700565 -0.07289717 0.06609842 0.07209982 0.13745363 0.07786325 -0.00670502 -0.10768702 -0.11650184 -0.02739821 0.15195981 -0.03856091 -0.02251505 0.01354269 0.20734778] [-0.19518134 0.18191135 0.1678455 0.03040497 0.06271093 -0.1636627 -0.00490535 0.12183905 0.09504748 0.14254275 0.28083947 -0.09219973 -0.04553385 -0.09144899 -0.10645168 0.09594533 0.01316203 0.21411076 -0.15279385 -0.00746152 0.17217697] [-0.19983874 0.2573593 0.22631709 -0.15415692 -0.09337292 -0.15152982 -0.17283896 0.18361479 -0.06854388 0.19458856 0.21257156 0.05420392 -0.13089052 0.0174079 -0.14144853 0.21623498 -0.09807283 0.17503688 -0.10952377 -0.11004093 0.12416358]]
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0.01375124 -0.12183258 0.02387197 -0.06606983 0.0128374 0.0583975 -0.15242508 0.25402677 0.08650441 -0.12012455 -0.1282458 -0.12729955 -0.18978614 0.05072499 0.13404976 -0.08693483 -0.05932525 0.0042815 -0.10118926 0.06298935 -0.11105874] [ 0.01753891 -0.26056546 -0.10591014 -0.07972372 0.14199689 0.01462964 -0.19376229 0.26298392 0.03121896 -0.12185458 -0.08381066 -0.18210396 -0.2994332 0.15547688 0.30042946 -0.15260921 -0.15108764 0.15606089 -0.15433192 0.02746347 -0.18043859] [-0.04145756 -0.12578137 -0.00489843 -0.08871794 0.12535645 -0.13544112 -0.24285382 0.16424173 0.11010998 0.03357387 0.04003523 -0.04463525 -0.20672305 -0.15384626 0.14917772 -0.0315588 -0.11572077 0.09100223 -0.3304635 -0.00747059 -0.00310737] [-0.02675727 -0.05899341 -0.09007816 -0.07874151 0.00275524 -0.08673215 -0.13637421 0.09531567 0.03647001 0.0274155 0.08433561 -0.01373554 -0.10794039 -0.10241009 0.13272324 -0.05780311 -0.0164403 0.05420679 -0.15006481 -0.00858256 0.07314492]]
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [ 0.01375124 -0.12183258 0.02387197 -0.06606983 0.0128374 0.0583975 -0.15242508 0.25402677 0.08650441 -0.12012455 -0.1282458 -0.12729955 -0.18978614 0.05072499 0.13404976 -0.08693483 -0.05932525 0.0042815 -0.10118926 0.06298935 -0.11105874] [ 0.01753891 -0.26056546 -0.10591014 -0.07972372 0.14199689 0.01462964 -0.19376229 0.26298392 0.03121896 -0.12185458 -0.08381066 -0.18210396 -0.2994332 0.15547688 0.30042946 -0.15260921 -0.15108764 0.15606089 -0.15433192 0.02746347 -0.18043859] [-0.04145756 -0.12578137 -0.00489843 -0.08871794 0.12535645 -0.13544112 -0.24285382 0.16424173 0.11010998 0.03357387 0.04003523 -0.04463525 -0.20672305 -0.15384626 0.14917772 -0.0315588 -0.11572077 0.09100223 -0.3304635 -0.00747059 -0.00310737]]]
Dense Results: (3L, 8L, 10L) [[[ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.10025913 0.09830095 0.09733046 0.10102039 0.10041437 0.09756581 0.10080832 0.11192366 0.10017618 0.09220069] [ 0.09878089 0.10108102 0.09184121 0.09864803 0.10347681 0.09838358 0.10482841 0.1048487 0.10316593 0.09494546] [ 0.09489174 0.10166328 0.09262799 0.09878846 0.10659764 0.09606978 0.09570783 0.11349867 0.10343732 0.09671728] [ 0.09291134 0.10653186 0.09883053 0.09483069 0.1175791 0.09242333 0.08018017 0.10308622 0.1072091 0.10641769] [ 0.10223442 0.10213351 0.08504557 0.10168333 0.10628794 0.09385473 0.09465507 0.10901623 0.09734224 0.10774697]]
[[ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09614025 0.10190803 0.10614263 0.09728929 0.10756206 0.09338962 0.09829859 0.10613411 0.09345927 0.09967621] [ 0.10851715 0.09379967 0.10825151 0.09030457 0.10776058 0.09704588 0.09527913 0.10205451 0.10119704 0.09579002] [ 0.10465015 0.09507099 0.10042831 0.08773478 0.10443532 0.09245131 0.10091849 0.09773334 0.11714043 0.09943689] [ 0.10257854 0.0982452 0.09618285 0.09166855 0.10581338 0.09391592 0.09284189 0.10499931 0.10973047 0.10402392]]
[[ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09971005 0.10061809 0.09926669 0.10123347 0.10159823 0.09358915 0.09746949 0.10518143 0.10171816 0.09961526] [ 0.09614025 0.10190803 0.10614263 0.09728929 0.10756206 0.09338962 0.09829859 0.10613411 0.09345927 0.09967621] [ 0.10851715 0.09379967 0.10825151 0.09030457 0.10776058 0.09704588 0.09527913 0.10205451 0.10119704 0.09579002] [ 0.10465015 0.09507099 0.10042831 0.08773478 0.10443532 0.09245131 0.10091849 0.09773334 0.11714043 0.09943689]]] 最终结果: [[7 7 7 7 7 7 4 7] [7 7 7 7 4 0 8 8] [7 7 7 7 7 4 0 8]]
可以看到Softmax的输入是对的,但是输出似乎不对,请大神继续指导:smile:
@jxwb088047 可否简明说一下问题,是不是LSTM的输出是被mask了,然后LSTM后面你要接一个TimeDistributed的Dense然后再softmax,你期望输出是什么呢?
@MoyanZitto LSTM的输出是成功mask了,softmax后的输出也期望是mask的,但似乎不成功,求大神指点:smile:
@jxwb088047 softmax的输出值时一个概率啊,怎么能被mask呢?
你好你的问题解决了吗,我遇到了类似的问题,网络结构是mask+lstm+TimeDistributed(Dense), lstm层之后还是有输出的,但是TimeDistributed(Dense)之后的输入就全为0,求大神解答,结构如下
input=Input(shape=(max_len,feat_dim),name='input_layer')
input_layer=input
mask=Masking(mask_value=0,name='mask_layer')(input_layer)
lstm_dt=LSTM(200,return_sequences=True,name='lstm_dt')(mask)
fc_dt=TimeDistributed(Dense(3,activation='relu'),name='fc_dt')(lstm_dt)