machine-learning-book
Confusing matrix operations in Chpt. 16
I find the matrix operations in Chpt. 16 confusing. For example, instead of:
keys = U_key.matmul(embedded_sentence.T).T
values = U_value.matmul(embedded_sentence.T).T
omega_23 = query_2.dot(keys[2])
attention_weights_2 = F.softmax(omega_2 / d**0.5, dim=0)
It's clearer to do:
queries = embedded_sentence @ U_query
keys = embedded_sentence @ U_key
values = embedded_sentence @ U_value
omega = queries @ keys.T
attention_weights = F.softmax(omega / d**0.5, dim=-1)
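For reference, here is a self-contained, runnable sketch of that single-head version. The concrete shapes (a 6-token sentence, 16-dimensional embeddings, square d x d projection matrices) are illustrative assumptions rather than the chapter's exact setup:
import torch
import torch.nn.functional as F

torch.manual_seed(123)
d = 16                                   # embedding dimension (assumed)
embedded_sentence = torch.rand(6, d)     # (seq_len, d)

U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)

queries = embedded_sentence @ U_query    # (seq_len, d)
keys = embedded_sentence @ U_key         # (seq_len, d)
values = embedded_sentence @ U_value     # (seq_len, d)

omega = queries @ keys.T                 # (seq_len, seq_len) unnormalized scores
attention_weights = F.softmax(omega / d**0.5, dim=-1)   # normalize over keys
context = attention_weights @ values     # (seq_len, d) context vectors
print(attention_weights.shape, context.shape)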
With multi-head attention, instead of:
stacked_inputs = embedded_sentence.T.repeat(8, 1, 1)
multihead_keys = torch.bmm(multihead_U_key, stacked_inputs)
multihead_keys = multihead_keys.permute(0, 2, 1)
# ... eventually giving up and using a random placeholder:
multihead_z_2 = torch.rand(8, 16)
we can just do:
multihead_queries = embedded_sentence @ multihead_U_query
multihead_keys = embedded_sentence @ multihead_U_key
multihead_values = embedded_sentence @ multihead_U_value
multihead_weights = F.softmax(multihead_queries @ multihead_keys.transpose(1, 2) / d**0.5, dim=-1)
multihead_z = multihead_weights @ multihead_values
which makes it clear that the multihead case is analogous to the single-head case.
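Here is a self-contained sketch of that multi-head version, plus a check that each head reproduces the single-head computation. The head count (8) and shapes are again illustrative assumptions, not necessarily the chapter's exact setup:
import torch
import torch.nn.functional as F

torch.manual_seed(123)
h, d = 8, 16                                     # heads and embedding dim (assumed)
embedded_sentence = torch.rand(6, d)             # (seq_len, d)

multihead_U_query = torch.rand(h, d, d)          # one (d, d) projection per head
multihead_U_key = torch.rand(h, d, d)
multihead_U_value = torch.rand(h, d, d)

# (seq_len, d) broadcasts against (h, d, d), giving (h, seq_len, d) per head
multihead_queries = embedded_sentence @ multihead_U_query
multihead_keys = embedded_sentence @ multihead_U_key
multihead_values = embedded_sentence @ multihead_U_value

multihead_omega = multihead_queries @ multihead_keys.transpose(1, 2)   # (h, seq_len, seq_len)
multihead_weights = F.softmax(multihead_omega / d**0.5, dim=-1)
multihead_z = multihead_weights @ multihead_values                     # (h, seq_len, d)

# Sanity check: head 0 equals the single-head computation with head 0's matrices
q0 = embedded_sentence @ multihead_U_query[0]
k0 = embedded_sentence @ multihead_U_key[0]
v0 = embedded_sentence @ multihead_U_value[0]
z0 = F.softmax(q0 @ k0.T / d**0.5, dim=-1) @ v0
print(torch.allclose(multihead_z[0], z0))        # True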
Thanks for the comment, and I 100% agree. Not sure why I made it unnecessarily complicated there. In my other book (Build an LLM from Scratch), I use a more legible version, similar to what you suggest: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb