Extension of SAN for complex-valued input
Hi Blaz,
First off, thank you for the open-access framework. I have already tested some of the architectures on test data and they produce great results. I was wondering whether I could pick your brain about possibly extending the SAN to accept complex-valued inputs, seeing as PyTorch already works with complex-valued data.
From what I read in some papers (https://arxiv.org/abs/1802.08026) and saw in other implementations (torchlex), we can simply apply the activation function to the split (real and imaginary) parts of the input. Regarding the softmax application, we can just take the squared magnitude ( Re^2(input) + Im^2(input) ) as the input of the softmax. Following this split-wise reasoning, the attention application remains the same, as we can just perform identical operations on both parts.
I tried extending some of your code but I'm stuck on the forward attention step, and I was wondering whether you could give me a hand (it is binary classification, so the labels are real numbers, but I'm not sure whether we should use BCE or MSE loss after the abs at the output):
Edited code
def complex_selu(input):
    """Split-type SELU: apply SELU to the real and imaginary parts separately.

    Returns a torch.complex128 tensor whose real/imaginary components are the
    SELU of the input's real/imaginary components respectively.
    """
    real_part = F.selu(input.real).type(torch.complex128)
    imag_part = F.selu(input.imag).type(torch.complex128)
    return real_part + 1j * imag_part
def apply_complex(fr, fi, input, dtype=torch.complex128):
    """Apply a complex-valued linear map given its real and imaginary parts.

    Implements (fr + i*fi)(re + i*im) = (fr(re) - fi(im)) + i*(fr(im) + fi(re)),
    where fr/fi are real-valued callables (e.g. nn.Linear layers).
    """
    real_out = (fr(input.real) - fi(input.imag)).type(dtype)
    imag_out = (fr(input.imag) + fi(input.real)).type(dtype)
    return real_out + 1j * imag_out
class ComplexSeLU(nn.Module):
    """Module wrapper around the split-type complex SELU activation."""

    def forward(self, input):
        # Delegate to the functional form defined alongside this module.
        return complex_selu(input)
class ComplexLinear(nn.Module):
    """Complex-valued affine layer built from two real nn.Linear sub-layers.

    fc_r holds the real component of the complex weight, fc_i the imaginary
    component; forward combines them via apply_complex.
    """

    def __init__(self, in_features, out_features):
        super(ComplexLinear, self).__init__()
        # One real-valued layer each for the real and imaginary components.
        self.fc_r = nn.Linear(in_features, out_features, dtype=torch.float64)
        self.fc_i = nn.Linear(in_features, out_features, dtype=torch.float64)

    def forward(self, input):
        # Combine the two real layers into a single complex affine transform.
        return apply_complex(self.fc_r, self.fc_i, input)
class E2EDatasetLoader(Dataset):
    """Minimal Dataset over a (sparse-convertible) feature matrix with
    optional targets.

    Each item is a (dense_row_tensor, target_tensor_or_None) pair.
    """

    def __init__(self, features, targets=None):
        # Normalise to CSR once. The original called csr_matrix(...) and then
        # .tocsr() again, and had an if/else whose two branches both assigned
        # self.targets = targets — collapsed here with identical behavior.
        self.features = sparse.csr_matrix(features)
        self.targets = targets

    def __len__(self):
        # Number of rows (instances) in the feature matrix.
        return self.features.shape[0]

    def __getitem__(self, index):
        instance = torch.from_numpy(self.features[index, :].todense())
        if self.targets is not None:
            target = torch.from_numpy(np.array(self.targets[index]))
        else:
            # NOTE(review): a None target breaks the default DataLoader
            # collate; this class is only iterated directly at predict time.
            target = None
        return instance, target
def to_one_hot(lbx):
    """One-hot encode a 1-D label array; returns a scipy sparse matrix."""
    encoder = OneHotEncoder(handle_unknown='ignore')
    label_column = lbx.reshape(-1, 1)
    return encoder.fit_transform(label_column)
class SANNetwork(nn.Module):
    """Complex-valued variant of the Self-Attention Network (SAN).

    All dense layers are ComplexLinear; attention coefficients are obtained
    by taking the magnitude (abs) of complex activations before the softmax,
    so the attention weights themselves are real-valued while the attended
    representation stays complex. The final sigmoid(abs(.)) maps the output
    to a real probability, so a real-valued loss (e.g. BCE) applies.
    """

    def __init__(self, input_size, num_classes, hidden_layer_size, dropout=0.02, num_heads=2, device="cuda"):
        super(SANNetwork, self).__init__()
        self.fc1 = ComplexLinear(input_size, input_size)
        self.device = device
        self.softmax = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=0)
        self.softmax3 = nn.Softmax(dim=-1)  # the last dim indicates the feature dim
        self.activation = ComplexSeLU()
        self.num_heads = num_heads
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = ComplexLinear(input_size, hidden_layer_size)
        self.fc3 = ComplexLinear(hidden_layer_size, num_classes)
        # One attention head per ComplexLinear projection.
        self.multi_head = nn.ModuleList([ComplexLinear(input_size, input_size) for k in range(num_heads)])

    def forward_attention(self, input_space, return_softmax=False):
        # Real-valued zeros; adding a complex attended matrix promotes the
        # result dtype to complex automatically.
        placeholder = torch.zeros(input_space.shape).to(self.device)
        for k in range(len(self.multi_head)):
            if return_softmax:
                attended_matrix = self.multi_head[k](input_space)
            else:
                # Softmax over the magnitude of the complex projection,
                # then element-wise re-weighting of the (complex) input.
                attended_matrix = self.softmax3(abs(self.multi_head[k](input_space))) * input_space
            placeholder = torch.add(placeholder, attended_matrix)
        # Average the per-head contributions.
        placeholder /= len(self.multi_head)
        out = placeholder
        if return_softmax:
            out = self.softmax(out)
        return out

    def get_mean_attention_weights(self):
        """Global attention: mean softmaxed magnitude of head weight diagonals."""
        activated_weight_matrices = []
        for head in self.multi_head:
            # BUGFIX: ComplexLinear has no .weight attribute (it wraps two
            # real nn.Linear layers), so the original `head.weight.data`
            # raised AttributeError. Reconstruct the complex weight matrix
            # and use its magnitude, consistent with forward_attention.
            wm = head.fc_r.weight.data + 1j * head.fc_i.weight.data
            diagonal_els = torch.diag(wm)
            activated_diagonal = self.softmax2(abs(diagonal_els))
            activated_weight_matrices.append(activated_diagonal)
        output_mean = torch.mean(torch.stack(activated_weight_matrices, dim=0), dim=0)
        return output_mean

    def forward(self, x):
        # attend and aggregate across heads
        out = self.forward_attention(x)
        # dense hidden (l1 in the paper)
        out = self.fc2(out)
        out = self.activation(out)
        # dense hidden (l2 in the paper, output); abs maps the complex
        # logits to real magnitudes before the sigmoid.
        out = self.fc3(out)
        out = self.sigmoid(abs(out))
        return out

    def get_attention(self, x):
        return self.forward_attention(x, return_softmax=True)

    def get_softmax_hadamand_layer(self):
        return self.get_mean_attention_weights()
class SAN:
    """Scikit-learn-style wrapper around SANNetwork: fit / predict and
    attention-inspection helpers."""

    def __init__(self, batch_size=32, num_epochs=32, learning_rate=0.001, stopping_crit=10, hidden_layer_size=64,num_heads=1,
                 dropout=0.2): # , num_head=1
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.loss = torch.nn.CrossEntropyLoss()
        self.loss = torch.nn.BCELoss()#This is for binary case
        #self.loss = torch.nn.NLLLoss()#This is for binary complex value case
        self.dropout = dropout
        self.num_heads = num_heads
        self.batch_size = batch_size
        # Early stopping: epochs without mean-loss improvement before break.
        self.stopping_crit = stopping_crit
        self.num_epochs = num_epochs
        self.hidden_layer_size = hidden_layer_size
        self.learning_rate = learning_rate
        self.model = None
        self.optimizer = None
        self.num_params = None
        # Human-readable hyperparameter tag for logging/bookkeeping.
        self.config = "{}-{}-{}-{}-{}".format(num_heads, learning_rate, hidden_layer_size, dropout, num_epochs)

    def fit(self, features, labels): # , onehot=False
        """Train a fresh SANNetwork on (features, labels); returns the
        per-epoch mean-loss vector."""
        # Manual one-hot encoding; assumes labels are integers usable as
        # indices in [0, nun) -- TODO confirm for non-contiguous labels.
        nun = len(np.unique(labels))
        one_hot_labels = []
        for j in range(len(labels)):
            lvec = np.zeros(nun)
            lj = labels[j]
            lvec[lj] = 1
            one_hot_labels.append(lvec)
        one_hot_labels = np.matrix(one_hot_labels)
        logging.info("Found {} unique labels.".format(nun))
        train_dataset = E2EDatasetLoader(features, one_hot_labels)
        dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=3)
        stopping_iteration = 0
        current_loss = np.inf
        self.model = SANNetwork(features.shape[1], num_classes=nun, hidden_layer_size=self.hidden_layer_size, num_heads = self.num_heads,
                                dropout=self.dropout, device=self.device).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.num_params = sum(p.numel() for p in self.model.parameters())
        logging.info("Number of parameters {}".format(self.num_params))
        logging.info("Starting training for {} epochs".format(self.num_epochs))
        print(self.model)
        loss_vec = []
        for epoch in range(self.num_epochs):
            logging.info("epoch {}".format(epoch))
            if stopping_iteration > self.stopping_crit:
                logging.info("Stopping reached!")
                break
            losses_per_batch = []
            self.model.train()
            # NOTE(review): the loop variables shadow the fit() arguments
            # `features`/`labels` from the first iteration onward.
            for i, (features, labels) in enumerate(dataloader):
                features = features.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(features)
                loss = self.loss(outputs, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses_per_batch.append(float(loss))
            mean_loss = np.mean(losses_per_batch)
            loss_vec.append(mean_loss)
            if mean_loss < current_loss:
                current_loss = mean_loss
                stopping_iteration = 0
            else:
                stopping_iteration += 1
            logging.info("------> mean loss per batch {}".format(mean_loss))
        return loss_vec

    def predict(self, features, return_proba=False):
        """Predict class indices (or raw per-instance outputs when
        return_proba=True)."""
        test_dataset = E2EDatasetLoader(features, None)
        predictions = []
        with torch.no_grad():
            for features, _ in test_dataset:
                self.model.eval()
                # NOTE(review): .float() casts to real float32 — for the
                # complex-valued variant the input presumably needs to stay
                # complex; confirm against the caller's data.
                features = features.float().to(self.device)
                representation = self.model(features)
                pred = representation.detach().cpu().numpy()[0]
                predictions.append(pred)
        if not return_proba:
            a = [np.argmax(a_) for a_ in predictions] # assumes 0 is 0
            return np.array(a).flatten()
        else:
            a = [a_ for a_ in predictions]
            return a

    def predict_proba(self, features):
        """Return the positive-class score per instance (binary assumption:
        index 1 of the output vector)."""
        test_dataset = E2EDatasetLoader(features, None)
        predictions = []
        self.model.eval()
        with torch.no_grad():
            for features, _ in test_dataset:
                features = features.float().to(self.device)
                representation = self.model.forward(features)
                pred = representation.detach().cpu().numpy()[0]
                predictions.append(pred)
        a = [a_[1] for a_ in predictions]
        return np.array(a).flatten()

    def get_mean_attention_weights(self):
        # Global attention, as a numpy array on CPU.
        return self.model.get_mean_attention_weights().detach().cpu().numpy()

    def get_instance_attention(self, instance_space):
        # Local (instance-level) attention; densify scipy sparse input first.
        if "scipy" in str(type(instance_space)):
            instance_space = instance_space.todense()
        instance_space = torch.from_numpy(instance_space).float().to(self.device)
        return self.model.get_attention(instance_space).detach().cpu().numpy()
Hi! Glad you were able to build on top of it — great work so far. Based on, e.g., the https://arxiv.org/pdf/1705.09792.pdf paper, cross-entropy loss is for sure one option. Alternatively, one would imagine the loss has to be holomorphic (complex-differentiable); I'm not sure if there are any widely used candidates for this. I'd try the simplest one first (cross-entropy-based), just to get a feel for the behavior — also, perhaps this was seen before by @Petkomat too.
Hi Blaz, thanks for the fast reply!
So I remained working on the complex-valued SAN and, thus, I edited the intial response with the current working version. So basically, I have a complex-valued tabular data X and the respective (binary) real class labels y
First, regarding the loss function: I would have thought that it would need to deal with complex-valued output but, considering the (eq. 36) alternative, the value at the output, out = self.sigmoid(abs(out)), is real, so I guess we can skip the complex-valued loss and just use BCE. Does this seem right?
Now, regarding the attention, I wrote the function below following (considering k>1). Again, following the same principle for the softmax, I can just compute the abs of the multiplication of the complex-valued weights and the input. After the posterior product by the input space we still get a complex-valued output at out = self.forward_attention(x). What do you think?
def forward_attention(self, input_space, return_softmax=False):
    """Proposed complex-valued attention: softmax is taken over abs(.) of the
    complex head projection, so the attention weights are real while the
    attended output remains complex."""
    # Real-valued zeros; adding a complex attended matrix promotes the dtype.
    placeholder = torch.zeros(input_space.shape).to(self.device)
    for k in range(len(self.multi_head)):
        if return_softmax:
            # Raw head projection (used when inspecting attention maps).
            attended_matrix = (self.multi_head[k](input_space))
        else:
            # Magnitude-softmax re-weighting of the complex input.
            attended_matrix = self.softmax3(abs(self.multi_head[k](input_space))) * input_space
        placeholder = torch.add(placeholder,attended_matrix)
    # Average over heads.
    placeholder /= len(self.multi_head)
    out = placeholder
    if return_softmax:
        out = self.softmax(out)
    return out
Thanks in advance for the help! If you want we can discuss it further
BCE seems plausible in this case; I'd suggest you just test it out. Re: the attention part — I can see how this might work, though I'm not sure it isn't too simplistic; I would test that out too (you can add examples, even to this repo, if generic enough).
Alright, that seems fun. I'll test it out on some benchmark datasets and get back to you.
fingers crossed
On Wed, 27 Sep 2023, 14:00 Tomás Soares da Costa, @.***> wrote:
Alright, that seems fun. I'll test it out on some benchmark datasets and get back to you.
— Reply to this email directly, view it on GitHub https://github.com/SkBlaz/san/issues/8#issuecomment-1737254643, or unsubscribe https://github.com/notifications/unsubscribe-auth/ACMSERBPG6SA35I4HTWMJ4DX4QIOXANCNFSM6AAAAAA5H4RDJ4 . You are receiving this because you commented.Message ID: @.***>
Apologies for the delay, only this week that I managed to dedicate some time onto the problem. Below are some results for the breast cancer dataset (just cast into the complex domain), with parameters equal to your simple benchmark example. I will further evaluate the consideration of other datasets but some initial comments to ensure I made no (initial mistakes):
- Since we are evaluating the attention, I guess we should just look at the global and local attention after training?
- Evidently the loss function is very different as there are noteworthy differences in the network (as we utilized split A function, the number of parameters increased (?)). The local attention seems quite sparse, is this expected (vertically)?
- Global attention appears different, which is expected seeing as they are two different networks?
- When I utilize my own electromagnetic datasets I also get diverging results but, in their context, they are in line with expectations. I don't have much sensitivity for this dataset, so I can't really tell if this is a consistent outcome.
- Any suggestions or helpfull tips I could check? 😅
(Upper Image is standard Real SAN and lower is complex equivalent, see titles)
Global Attention:
Local Attention:
Loss:
Loss:
Hi! Very neat work. Behavior is very different indeed, worth checking the overfitting perhaps. Local focus makes some sense given your formulation, yet it's hard to say whether that's expected.
More extensive hyperparameter search seems in order too ..
On Tue, 3 Oct 2023, 17:43 Tomás Soares da Costa, @.***> wrote:
Apologies for the delay, only this week that I managed to dedicate some time onto the problem. Below are some results for the breast cancer dataset (just cast into the complex domain), with parameters equal to your simple benchmark example. I will further evaluate the consideration of other datasets but some initial comments to ensure I made no (initial mistakes):
- Since we are evaluating the attention, I guess we should just look at the global and local attention after training?
- Evidently the loss function is very different as there are noteworthy differences in the network (as we utilized split A function, the number of parameters increased (?)). The local attention seems quite sparse, is this expected (vertically)?
- Global attention appears different, which is expected seeing as they are two different networks?
- When I utilize my own electromagnetic datasets I also get diverging results but, for its context, they are inlined with the expected. I dont have much sensisivity for this dataset so I cant really tell if this is a colinear outcome.
- Any suggestions or helpfull tips I could check? 😅
(Upper Image is standard Real SAN and lower is complex equivalent, see titles) Global Attention: [image: image] https://user-images.githubusercontent.com/68435474/272314978-51e5ac41-da01-4edd-8a46-6ffba85127b6.png [image: image] https://user-images.githubusercontent.com/68435474/272315234-f1acc7cd-e6ec-4c1a-b906-ad79761bdea3.png
Local Attention: [image: image] https://user-images.githubusercontent.com/68435474/272315125-e420162b-b0a6-4fa7-96d1-f4d8ba11578b.png [image: image] https://user-images.githubusercontent.com/68435474/272315260-2e86bd51-82bc-4205-b2ed-b2c2d541e727.png
Loss: [image: image] https://user-images.githubusercontent.com/68435474/272315166-10353305-a543-4c0d-b2d5-a210a23c0bd1.png [image: image] https://user-images.githubusercontent.com/68435474/272315310-f69acffc-4384-438e-aef4-969dffd05cc3.png
Loss: [image: image] https://user-images.githubusercontent.com/68435474/272315637-124d3558-3ab3-4e59-b03a-9749a4bee1dd.png [image: image] https://user-images.githubusercontent.com/68435474/272315588-3a32d57d-e42c-431d-a15c-40eaea891cad.png
— Reply to this email directly, view it on GitHub https://github.com/SkBlaz/san/issues/8#issuecomment-1745243081, or unsubscribe https://github.com/notifications/unsubscribe-auth/ACMSERHSFEKEKIYD7DKEPCLX5QXBTAVCNFSM6AAAAAA5H4RDJ6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMYTONBVGI2DGMBYGE . You are receiving this because you commented.Message ID: @.***>
Hi again,
Apologies for the late reply but I had other deadlines and classes! 😅
Still, I have read some more works on the topic of complex-valued NNs and this seems like a fitting strategy! I will run some more simulations to address overfitting. I read again that you used ten fold stratified cross validation to select the best model before evaluating the feature importance but how much did you change parameters, extensive grid search? If so did you use Ray tune and how much did you change the configuration?
Thanks again for the help.
Best, Tomás
Hi! Great to hear from you,
It was grid yes, Ray tune is a very good idea/upgrade
On Tue, 28 Nov 2023, 12:56 Tomás Soares da Costa, @.***> wrote:
Hi again,
Apologies for the late reply but I had other deadlines and classes! 😅
Still, I have read some more works on the topic of complex-valued NNs and this seems like a fitting strategy! I will run some more simulations to address overfitting. I read again that you used ten fold stratified cross validation to select the best model before evaluating the feature importance but how much did you change parameters, extensive grid search? If so did you use Ray tune and how much did you change the configuration?
Thanks again for the help.
Best, Tomás
— Reply to this email directly, view it on GitHub https://github.com/SkBlaz/san/issues/8#issuecomment-1829694011, or unsubscribe https://github.com/notifications/unsubscribe-auth/ACMSERH2ADJJFAUOB7YL5KTYGXGOLAVCNFSM6AAAAAA5H4RDJ6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMYTQMRZGY4TIMBRGE . You are receiving this because you commented.Message ID: @.***>
Hi again Blaz,
Great news, I managed to implement RayTune and indeed it works for both real- and complex-valued outputs. Ultimately, the main question remains the same. How to interpret the complex-valued weights, i.e., how to properly formulate the self attention. Most of the literature uses the Magnitude to output the attention but I do wonder whether that does not make the use of complex-values negligible. Dont suppose you might know of some papers discussing other attention strategies?
Either way, I want to ask about the last details: I take my tabular data and apply Ray Tune to get the optimal hyperparameter combination — should I then train the model on the whole dataset or perform k-fold CV (which feels redundant with Ray Tune)? My guess is that with the whole dataset I might overfit, but that's fine, no? Do I have to apply an external classifier to confirm that the top selected features provide the best accuracy (seems redundant)? Did you compare the SAN with the TabNet implementation?
Apologies again for the lengthy set of questions!
Hi! Great to hear from you. So:
- TabNet was not out when we pushed san paper
- interpretation is normally magnitude, not aware of alternatives
- raytune always on train data. So, train data gives you classifier and rankings.
On Mon, 11 Dec 2023, 18:01 Tomás Soares da Costa, @.***> wrote:
Hi again Blaz,
Great news, I managed to implement RayTune and indeed it works for both real- and complex-valued outputs. Ultimately, the main question remains the same. How to interpret the complex-valued weights, i.e., how to properly formulate the self attention. Most of the literature uses the Magnitude to output the attention but I do wonder whether that does not make the use of complex-values negligible. Dont suppose you might know of some papers discussing other attention strategies?
Either way I want to ask the last details: So I take my tabular data, apply ray tune to get the optimal hyperparameter combination and then should I train the model with the whole dataset or perform Kfold CV (I feel its redundant with RayTune)? My guess is that for the with the whole dataset I might overfit but that's good no? Do I have to apply an external classifier to confirm that the top selected features provide the best accuracy (Seems redundant)? Did you compare the SAN with the TabNet implementation?
Apologies again for the lengthy set of questions!
— Reply to this email directly, view it on GitHub https://github.com/SkBlaz/san/issues/8#issuecomment-1850489618, or unsubscribe https://github.com/notifications/unsubscribe-auth/ACMSEREJ3ACYHSVGCLNZW2TYI436VAVCNFSM6AAAAAA5H4RDJ6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMYTQNJQGQ4DSNRRHA . You are receiving this because you commented.Message ID: @.***>
@SantaTitular pretty sure TabNet doesn't scale same way this architecture does fyi