Trace frame-level scores using lattice
Hi, team!
I have encountered a problem with k2 in my code. Below is a description of the problem.
For a nnet_output with shape [B, T, D], I am trying to calculate the scores on a graph (MMI numerator or denominator) for any prefix segment of nnet_output, namely nnet_output[:, :t, :], where t is any index smaller than T (the total length along the time axis).
Currently, I implement this with a loop, but that is computationally expensive. My code is below:
graph, _ = self.graph_compiler.compile(texts, self.P, replicate_den=True)
T = x.size()[1]
tot_scores = []
for t in range(T, 0, -1):
    supervision = torch.Tensor([[0, 0, t]]).to(torch.int32)  # [idx, start, length]
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
    lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0)
    frame_score = lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    tot_scores.append(frame_score)
tot_scores = torch.cat(tot_scores)
Could these scores be calculated by parsing the lats obtained from the whole nnet_output, i.e., could we calculate them with only one call to k2.intersect_dense? An approximation is also fine for me.
Thanks for your help! :)
Additional information:
- We do NOT need this process to be differentiable.
- It seems that the scores of each state and arc are accessible. Could we solve this problem using those scores on the lattice?
Please refer to the help doc of k2.intersect_dense:
https://k2-fsa.github.io/k2/python_api/api.html#k2.intersect_dense
There are two extra optional arguments relevant here, seqframe_idx_name and frame_idx_name:
def intersect_dense(a_fsas: Fsa,
b_fsas: DenseFsaVec,
output_beam: float,
a_to_b_map: Optional[torch.Tensor] = None,
seqframe_idx_name: Optional[str] = None,
frame_idx_name: Optional[str] = None) -> Fsa:
You can use
lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0, seqframe_idx_name='seqframe', frame_idx_name='frame')
After the above call, the resulting lats has two extra 1-D tensor attributes: seqframe and frame. That is, you can access them using
lats.seqframe
lats.frame
In your case, the supervision contains only one utterance, so lats.seqframe and lats.frame should be the same.
lats.frame contains values in the range from 0 to T, and it has as many entries as lats.num_arcs.
You can invoke k2.intersect_dense only once by feeding all T frames, and then call lats.get_forward_scores to get the forward score of each state. After this, you can use lats.frame to identify which states correspond to the t-th frame, and then use log_sum_exp to sum the scores of those states.
[EDITED]: The whole process is also differentiable.
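A minimal sketch of this recipe (my reading of it, assuming a single utterance; the state-grouping step is fleshed out further down the thread):
supervision = torch.tensor([[0, 0, T]], dtype=torch.int32)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0,
                          seqframe_idx_name='seqframe', frame_idx_name='frame')
forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)
# lats.frame has one entry per arc; group each arc's dest state by its frame
# index, then log-sum-exp the forward scores within each group.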
Thanks for the reply!
It seems that lats.frame represents the frame index of each arc. But I'm not sure we can recover the mapping from state to frame by lats.frame only.
Previously, I did not notice the API lats.frame. However, I have recovered this mapping in another way (see code below):
def trace_frame(lats):
    arcs = lats[0].as_dict()['arcs']
    frame2state = []
    prev_buf, cur_buf = [0], []
    for arc in arcs:
        f, t, _, _ = arc
        f, t = int(f), int(t)
        if f in prev_buf:
            if t not in cur_buf:
                cur_buf.append(t)
        else:
            frame2state.append(prev_buf)
            prev_buf = cur_buf
            cur_buf = [t]
    frame2state.append(prev_buf)  # last frame
    frame2state.append([t])  # final state
    return frame2state
After that, I tried to compute the frame-level scores from frame2state and lats.get_forward_scores, followed by a log_sum_exp, but there is a large gap between the results obtained this way and those from the loop-based approach I presented first. Below is the code:
_, den = self.graph_compiler.compile(texts, self.P, replicate_den=True)
T = x.size()[1]
print(T)
den_scores = []
for t in range(T, 0, -1):
    supervision = torch.Tensor([[0, 0, t]]).to(torch.int32)  # [idx, start, length]
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
    den_lats = k2.intersect_dense(den, dense_fsa_vec, output_beam=10.0)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    den_scores.append(den_tot_scores)
den_scores = torch.cat(den_scores).unsqueeze(0)  # [T] -> [B, T]
print("den score computed from previous version: ", den_scores)

# new implementation
supervision = torch.Tensor([[0, 0, T]]).to(torch.int32)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
den_lats = k2.intersect_dense(den, dense_fsa_vec, output_beam=10.0,
                              seqframe_idx_name='seqframe', frame_idx_name='frame')
frame2state = trace_frame(den_lats)
den_forward_scores = den_lats.get_forward_scores(log_semiring=True, use_double_scores=True)
assert len(frame2state) == T + 2  # extra start and end state
den_scores_new = []
for t in range(T, 0, -1):
    states = frame2state[t]
    den_score = torch.logsumexp(den_forward_scores[states], dim=0)
    den_scores_new.append(den_score)
den_scores_new = torch.stack(den_scores_new, dim=-1)
print("den score computed from new version: ", den_scores_new)
print("difference of the two versions: ", den_scores_new - den_scores)
And the results of the two versions do not match:
den score computed from previous version: tensor([[783.3993, 781.7219, 780.0333, 778.3439, 776.5803, 774.7105, 772.7158,
770.5528, 768.1100, 762.4893, 754.7248, 746.1509, 736.5164, 726.7509,
716.2331, 706.6452, 697.4263, 686.9293, 681.2793, 674.0765, 664.6529,
657.2257, 650.4535, 643.3796, 634.3289, 628.0670, 619.4294, 609.9516,
600.0689, 594.3592, 590.8794, 585.2605, 581.2088, 578.4872, 572.3086,
571.1170, 566.4870, 558.6632, 552.9537, 545.4799, 541.9318, 530.6125,
521.3177, 511.8524, 502.7617, 492.8600, 487.0916, 479.4112, 471.2255,
462.2586, 454.3531, 442.9971, 436.3475, 428.7141, 420.9272, 413.6736,
406.7801, 397.0022, 391.5931, 386.9415, 379.1681, 370.1025, 361.0875,
351.7468, 341.6609, 333.3474, 325.8754, 318.5557, 309.9469, 303.1736,
296.7173, 286.5587, 280.5947, 273.7012, 266.3182, 258.1839, 250.6710,
242.8872, 236.6474, 229.7769, 221.9658, 213.3599, 204.9579, 192.9310,
186.8552, 179.3116, 170.0203, 160.0337, 150.8213, 142.6257, 134.6330,
127.6783, 116.1433, 103.3247, 91.5196, 80.5909, 69.7985, 59.0771,
48.1386, 36.4054, 23.1133, 11.1909, -2.9392]],
dtype=torch.float64)
den score computed from new version: tensor([783.3993, 781.7252, 780.0382, 778.3510, 776.5909, 774.7267, 772.7416,
770.5969, 768.1950, 762.6968, 755.2699, 747.3608, 738.8629, 730.4765,
725.2539, 716.5815, 707.1374, 698.3506, 691.5351, 679.1446, 670.6063,
663.6814, 656.9236, 651.3605, 647.0456, 640.6885, 626.9087, 617.5355,
610.5875, 604.0881, 598.1472, 590.7349, 588.7432, 585.2757, 583.0598,
581.2615, 570.5313, 566.4218, 561.7389, 557.9919, 554.6887, 543.3505,
529.1934, 521.2846, 512.1856, 504.0500, 500.0110, 490.2833, 477.4947,
471.4613, 463.3775, 455.1872, 449.6866, 438.6009, 426.2327, 421.9648,
415.6079, 408.7287, 403.9518, 395.7307, 387.3507, 378.8256, 370.0204,
361.4262, 355.0736, 346.0511, 331.8522, 322.9216, 317.4937, 311.1927,
304.5424, 297.9715, 292.8139, 282.6298, 272.3836, 264.5364, 257.1948,
250.8738, 245.0846, 235.5465, 227.8981, 221.6154, 213.5615, 205.2312,
199.0696, 187.7953, 177.1561, 169.5724, 160.6012, 151.7466, 146.6541,
140.5165, 129.1156, 116.3533, 104.6344, 93.8204, 83.1367, 72.5044,
61.6243, 49.9238, 36.6502, 24.7326, 10.5949], dtype=torch.float64)
difference of the two versions: tensor([[ 0.0000, 0.0034, 0.0049, 0.0071, 0.0106, 0.0162,
0.0258, 0.0441, 0.0850, 0.2075, 0.5451, 1.2099,
2.3464, 3.7256, 9.0209, 9.9363, 9.7111, 11.4214,
10.2558, 5.0681, 5.9534, 6.4558, 6.4701, 7.9809,
12.7167, 12.6216, 7.4792, 7.5838, 10.5185, 9.7289,
7.2677, 5.4744, 7.5344, 6.7885, 10.7512, 10.1445,
4.0443, 7.7586, 8.7853, 12.5120, 12.7569, 12.7380,
7.8757, 9.4322, 9.4238, 11.1901, 12.9194, 10.8722,
6.2692, 9.2026, 9.0244, 12.1901, 13.3391, 9.8868,
5.3055, 8.2912, 8.8277, 11.7265, 12.3586, 8.7891,
8.1826, 8.7230, 8.9329, 9.6794, 13.4127, 12.7037,
5.9769, 4.3660, 7.5468, 8.0192, 7.8250, 11.4128,
12.2193, 8.9286, 6.0653, 6.3525, 6.5239, 7.9867,
8.4372, 5.7696, 5.9323, 8.2554, 8.6036, 12.3002,
12.2144, 8.4837, 7.1358, 9.5387, 9.7798, 9.1209,
12.0211, 12.8382, 12.9724, 13.0286, 13.1148, 13.2295,
13.3382, 13.4273, 13.4856, 13.5184, 13.5370, 13.5417,
13.5341]], dtype=torch.float64)
Currently, I don't know which version is correct.
The motivation for this is to compute the probability P(W|O_{1:t}) according to the definition of LF-MMI, but using only the first several frames.
arcs = lats.arcs.values()[:, :2]
# arcs is a 2-D torch.int32 tensor
for idx, (src, dst) in enumerate(arcs.tolist()):
    # note src is not used and you can replace it with an underscore _
    frame_idx = lats.frame[idx]
    # now you know the state `dst` belongs to the frame `frame_idx`.
    # You can add the forward_score of this state to a list corresponding
    # to the frame `frame_idx`.
    #
    # Caution: You have to avoid adding the `dst` state multiple times.
# At this point, you know the states corresponding to each frame;
# you can use `log-sum-exp` to combine them.
But I'm not sure we can recover the mapping from state to frame by lats.frame only.

As I posted above, you can iterate over the arcs; for each arc, you can get its frame_idx and dest_state. This dest_state belongs to the frame frame_idx.
If you have multiple utterances, then you have to use seqframe_idx.
Note: I have re-edited the demo code.
Previously, I did not notice the API lats.frame. However, I have recovered this mapping in another way (see code below)

Please use a small lats to verify that your code is correct. (You can print the resulting lats, its frame, its states, and the return value frame2state to check the correctness of your code.)
~~Looks like your frame2state is a 1-d list, which is not correct.~~
A frame can correspond to multiple states, while a state belongs to only one frame. You can note down which state belongs to which frame. For example:
frame_0_states = [1, 2, 3]
frame_1_states = [4, 5, 6, 7, 8, 9]
....
You can get the total scores for frame 0 using
forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)
frame_0_tot_scores = forward_scores[frame_0_states].exp().sum().log()
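Equivalently, torch.logsumexp performs the same combination without leaving log space, which is more numerically stable than exp().sum().log():
frame_0_tot_scores = torch.logsumexp(forward_scores[frame_0_states], dim=0)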
Note: For the last frame T-1, you have to consider the scores on the arcs entering the final state if those scores are not zero.
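One way to handle this caveat, assuming the lattice has a single final state (this is also what the resolution further down converges on): the forward score of the final state already includes the scores on the arcs entering it, so the total score for the last frame can be read off directly:
last_frame_tot_score = forward_scores[-1]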
@csukuangfj
As suggested, I restricted the nnet_output to 3 frames and checked the frame2state. I suppose my frame2state is correct.
den score computed from previous version: tensor([[23.1133, 11.1909, -2.9392]], dtype=torch.float64)
arcs: tensor([[ 0, 1, 0, 1083890428],
[ 0, 2, 147, 1093236685],
[ 1, 13, 147, 1093883163],
[ 2, 3, 21, -1083650674],
[ 2, 4, 35, 1042883936],
[ 2, 5, 43, -1103412784],
[ 2, 6, 48, -1091424240],
[ 2, 7, 111, -1081929866],
[ 2, 8, 112, -1090075914],
[ 2, 9, 114, 1050298232],
[ 2, 10, 115, -1070879172],
[ 2, 11, 140, -1089980508],
[ 2, 12, 141, -1078741436],
[ 2, 13, 147, 1096957993],
[ 2, 14, 153, -1097704528],
[ 2, 15, 189, -1077760680],
[ 3, 16, 1, 1081681798],
[ 4, 16, 1, 1079942965],
[ 5, 16, 1, 1078807868],
[ 6, 16, 1, 1082282282],
[ 7, 16, 1, 1080307114],
[ 8, 16, 1, 1080573812],
[ 9, 16, 1, 1082305666],
[ 10, 16, 1, 1085392658],
[ 11, 16, 1, 1081433908],
[ 12, 16, 1, 1082813821],
[ 13, 18, 0, 1092706735],
[ 13, 16, 1, -1058648121],
[ 13, 17, 147, 1094443926],
[ 14, 16, 1, 1081785469],
[ 15, 16, 1, 1082840980],
[ 16, 19, -1, -1221591040],
[ 17, 19, -1, -1051153366],
[ 18, 19, -1, -1051153366]],
dtype=torch.int32)
lats.frame: tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 3, 3, 3], dtype=torch.int32)
frame2state: [[0], [1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]
den score computed from new version: tensor([36.6502, 24.7326, 10.5949], dtype=torch.float64)
difference of the two versions: tensor([[13.5370, 13.5417, 13.5341]], dtype=torch.float64)
Note: For the last frame T-1, you have to consider the scores on the arcs entering the final state if those scores are not zero.
I'm also worried about the scores on the final states:
If nnet_output[:, :t, :] is used in intersect_dense, the states that belong to the t-th frame should be treated as final states, and the final scores should be taken into account. (In k2 this is implemented by an additional final state, with these final scores placed on some additional arcs.) This is what happens in my original implementation.
However, when tracing the lattice, the states that belong to the other frames (all except the last frame) are not treated as final states and do not add the extra final scores. That may cause a difference in the scores.
frame2state: [[0], [1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]

[0] is the start state and does not belong to any frame. It looks like there is an off-by-one error.
For arc i, we get its dest_state (not src_state) and frame_idx = lats.frame[i]; then dest_state belongs to frame frame_idx.
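For example, in the lattice dumped above, arc 0 goes 0 -> 1 and lats.frame[0] is 0, so state 1 belongs to frame 0; arc 2 goes 1 -> 13 and lats.frame[2] is 1, so state 13 belongs to frame 1.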
I have dumped the lattice above. If the frames are indexed from 1, I suppose we should do the logsumexp on the groups below:
frame_1_states = [1, 2]
frame_2_states = [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15]
frame_3_states = [16, 18, 17]
[0] is the start state and does not belong to any frame. It looks like there is an off-by-one error.

In my frame2state, the start and end states are the first and last entries, respectively. Note that the loop for t in range(T, 0, -1) with T=3 means only t=3, 2, 1 would be visited, which ignores both [0] and [19].
For arc i, we get its dest_state (not src_state) and frame_idx = lats.frame[i]; then dest_state belongs to frame frame_idx.
Following this advice, I have also revised my trace_frame function as below:
def trace_frame(lats):
    arcs = lats.arcs.values()[:, :2]
    T = max(lats.frame).item()
    frame2state = [[] for _ in range(T + 1)]
    for idx, (_, dst) in enumerate(arcs.tolist()):
        frame_idx = lats.frame[idx]
        if dst not in frame2state[frame_idx]:
            frame2state[frame_idx].append(dst)
    print("frame2state: ", frame2state)
    return frame2state
and its output is below. It only ignores the start state [0], while the other entries are the same:
frame2state: [[1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]
Using the new frame2state, the difference between the two versions is unchanged:
den score computed from previous version: tensor([[23.1133, 11.1909, -2.9392]], dtype=torch.float64)
frame2state: [[1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]
logsumexp on [16, 18, 17] is 36.650240146578824
logsumexp on [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15] is 24.732586652356957
logsumexp on [1, 2] is 10.594917989250451
den score computed from new version: tensor([36.6502, 24.7326, 10.5949], dtype=torch.float64)
difference of the two versions: tensor([[13.5370, 13.5417, 13.5341]], dtype=torch.float64)
Note:
- I have changed assert len(frame2state) == T + 2 to assert len(frame2state) == T + 1, since [0] does not exist now.
- I have changed states = frame2state[t] to states = frame2state[t-1], also because [0] does not exist now.
The log above presents all the states in each logsumexp. Now I suppose we don't have the off-by-one error.
Note: For the last frame T-1, you have to consider the scores on the arcs entering the final state if those scores are not zero.
As shown in the lattice above, the arcs 17->19 and 18->19 have scores of about -13.54. So the scores on states 17 and 18 should be revised before the logsumexp. (So far I don't do this.)
I'm also curious whether a similar modification should be done for the states that belong to the other frames, since in my previous implementation each frame is treated as the final frame once. If needed, how should this be done?
Thanks :)
I just created a colab notebook (see https://colab.research.google.com/drive/1iyc_q8aHuKd-RZxtYv9EqfyjB2QZDSOx?usp=sharing) to verify the idea.
The following code you posted:
graph, _ = self.graph_compiler.compile(texts, self.P, replicate_den=True)
T = x.size()[1]
tot_scores = []
for t in range(T, 0, -1):
    supervision = torch.Tensor([[0, 0, t]]).to(torch.int32)  # [idx, start, length]
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
    lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0)
    frame_score = lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    tot_scores.append(frame_score)
tot_scores = torch.cat(tot_scores)
is not equivalent to the one where the tot_scores is computed from the whole lattice (when you feed T frames at once).
The following code you posted [...] is not equivalent to the one where the tot_scores is computed from the whole lattice (when you feed T frames at once)
My observation is the same: the tot_scores computed by the two methods are different.
Is there anything wrong with my code? Or any suggestions on this mismatch?
For the following dense_fsa_vec, with the following decoding graph (you can find the code for all the comments below in the above colab notebook):
[images of the dense_fsa_vec and the decoding graph omitted]
Feed T frames at once
If you feed T frames at once, the resulting lats is:
[lattice image omitted]
Its lats.frame and lats.get_forward_scores are:
[output omitted]
The tot_scores for frame 0 is:
[output omitted]
(Note: state 1 and state 2 belong to frame 0, so torch.tensor([1, 2]) is used above.)
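The computation is presumably along the lines of the earlier snippet, i.e. (a sketch):
frame_0_tot_scores = forward_scores[torch.tensor([1, 2])].exp().sum().log()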
Feed frames separately
(1) If you feed only 1 frame, i.e., frame 0, the resulting lats is:
[lattice image omitted]
Caution: This lats is different from the one obtained when you feed 3 frames at once, so the tot_scores computed from this lats is different from the one computed from the whole lattice.
(2) If you feed only 2 frames, i.e., frame 0 and frame 1, the resulting lats is:
[lattice image omitted]
Caution: The states for frame 0 are the same as the ones contained in the whole lattice. However, the states for frame 1 are different from those contained in the whole lattice, so the tot_scores computed for frame 1 using this lattice differs from the one computed from the whole lattice.
Note: The tot_scores for frame 0 computed from this lats is the same as the one computed from the whole lattice.
So in the initial version of your code, I would recommend feeding t+1 frames to compute the tot_scores for frame t. If you feed only t frames, then the result is not correct.
If I don't misunderstand:
Given the lattice dumped from t+1 frames, only the first t frames of this lattice form a sub-graph of the whole lattice, while the states that belong to the (t+1)-th frame should be ignored.
If this is true, I should compute the scores on the first T-1 frames by the logsumexp method, while the score for the T-th frame is obtained from the score on the final state.
Also, I should never call lats.get_tot_scores in my implementation.
I'm still curious about the following:
As defined by LF-MMI, the log-probability log P(W|O_{1:t}) is the difference between the numerator and denominator scores. If I use only the first t frames to get the lattice, and then use lats.get_tot_scores on both the numerator and the denominator to get these scores, is that wrong in this scenario?
In other words, would lats.get_tot_scores lead to a wrong number, in some sense, for P(W|O_{1:t}) if t is not equal to T?
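For reference, the definition being assumed here, written out (all scores in the log domain; num and den denote the numerator and denominator graphs):
log P(W | O_{1:t}) = tot_score(num, O_{1:t}) - tot_score(den, O_{1:t})
where each term is the get_tot_scores of the corresponding prefix lattice.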
In other words, would lats.get_tot_scores lead to a wrong number, in some sense, for P(W|O_{1:t}) if t is not equal to T?
Maybe @danpovey has more to say about it.
Given the lattice dumped from t+1 frames, only the first t frames of this lattice form a sub-graph of the whole lattice, while the states that belong to the (t+1)-th frame should be ignored.
I was explaining why the tot_scores obtained by feeding frames separately is different from the one computed from the whole lattice. If you want to get identical results, you have to feed t+1 frames to compute the tot_scores for the t-th frame.
Thanks!
As @csukuangfj advised, I have rewritten the two methods, and currently the results match. The key points are:
- for the last frame, we use the last element of forward_scores (the score on the final state);
- for the other frames, we feed t+1 frames to obtain the score on the t-th frame.
import k2
import torch


def trace_lattice(lats):
    arcs = lats.arcs.values()[:, :2]
    T = max(lats.frame).item()
    frame2state = [[] for _ in range(T + 1)]
    for idx, (_, dst) in enumerate(arcs.tolist()):
        frame_idx = lats.frame[idx]
        if dst not in frame2state[frame_idx]:
            frame2state[frame_idx].append(dst)
    return frame2state


def compute_frame_level_scores(graph, nnet_output):
    T = nnet_output.size()[1]
    # dump the lattice
    supervision = torch.Tensor([[0, 0, T]]).to(torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
    lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0,
                              seqframe_idx_name='seqframe', frame_idx_name='frame')
    # compute frame-level scores
    forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)
    frame2states = trace_lattice(lats)
    assert len(frame2states) == T + 1  # extra final state
    tot_scores = []
    for t in range(T, 0, -1):
        if t == T:
            # score for the last frame: the final state's forward score
            tot_scores.append(forward_scores[-1])
        else:
            # scores for the other frames
            states = frame2states[t - 1]
            frame_score = torch.logsumexp(forward_scores[states], dim=-1)
            tot_scores.append(frame_score)
            print(f"scores computed from states {states} is {frame_score}")
    tot_scores = torch.stack(tot_scores, dim=0)
    return tot_scores
def compute_frame_level_scores_loop(graph, nnet_output):
    T = nnet_output.size()[1]
    tot_scores = []
    for t in range(T, 0, -1):
        # feed one more frame if it's not the last frame, so that the
        # states in the first t frames are identical to those in the
        # whole lattice
        t_ = t if t == T else t + 1
        supervision = torch.Tensor([[0, 0, t_]]).to(torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
        lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0,
                                  seqframe_idx_name='seqframe', frame_idx_name='frame')
        forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)
        frame2states = trace_lattice(lats)
        if t == T:
            tot_scores.append(forward_scores[-1])
        else:
            assert len(frame2states) == t + 2
            states = frame2states[t - 1]
            frame_score = torch.logsumexp(forward_scores[states], dim=-1)
            tot_scores.append(frame_score)
            print(f"scores computed from states {states} is {frame_score}")
    tot_scores = torch.stack(tot_scores, dim=0)
    return tot_scores
if __name__ == '__main__':
    nnet_output = torch.tensor(
        [
            [0.1, 0.22, 0.28, 0.4],
            [0.1, 0.13, 0.07, 0.7],
            [0.6, 0.2, 0.05, 0.15],
        ], dtype=torch.float32
    ).unsqueeze(0)
    nnet_output = torch.nn.functional.log_softmax(nnet_output, -1)
    graph = k2.ctc_graph([[1]])

    scores = compute_frame_level_scores(graph, nnet_output)
    print("Scores computed by new version: ", scores)

    scores = compute_frame_level_scores_loop(graph, nnet_output)
    print("Scores computed by original version: ", scores)
The results:
scores computed from states [3, 4, 5] is -1.652332052730104
scores computed from states [1, 2] is -0.787191693952512
Scores computed by new version: tensor([-2.4776, -1.6523, -0.7872], dtype=torch.float64)
scores computed from states [3, 4, 5] is -1.652332052730104
scores computed from states [1, 2] is -0.787191693952512
Scores computed by original version: tensor([-2.4776, -1.6523, -0.7872], dtype=torch.float64)
Maybe @danpovey has more to say about it.

Looking forward to Dan's reply :)