
Dimension Mismatch in torch_connector.py

Open: miles0428 opened this issue 1 year ago • 4 comments

Environment

  • Qiskit Machine Learning version: 0.7.0
  • Python version: 3.9.6
  • Operating system: macOS

What is happening?

I am attempting to build a quantum version of a convolutional layer with Qiskit and PyTorch, but I encounter an error in the Einstein summation when I execute loss.backward(). I believe the issue arises because the tensor dimensions do not match.

The problem specifically stems from the fact that torch_connector.py uses weights_grad = torch.einsum("ij,ijk->k", grad_output.detach().cpu(), weights_grad) for the 3D case. In my implementation, however, both operands carry an extra dimension: grad_output.detach().cpu() is 3D and weights_grad is 4D. To resolve this, I modified the expression from "ij,ijk->k" to "ijl,ijlk->k", which corrected the problem.
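
For illustration, a minimal, self-contained sketch of the mismatch (with made-up shapes; the extra leading axis corresponds to the unfolded patch positions that the convolutional layer feeds through the QNN):

import torch

# shapes as in the library's 3D case: grad_output is (batch, outputs),
# weights_grad is (batch, outputs, num_weights)
grad_output = torch.rand(5, 4)
weights_grad = torch.rand(5, 4, 7)
print(torch.einsum("ij,ijk->k", grad_output, weights_grad).shape)  # torch.Size([7])

# with an extra leading axis, the operands become 3D and 4D and the
# original subscripts no longer match the operand ranks
grad_output = torch.rand(9, 5, 4)
weights_grad = torch.rand(9, 5, 4, 7)
# torch.einsum("ij,ijk->k", grad_output, weights_grad)  # raises the RuntimeError below
print(torch.einsum("ijl,ijlk->k", grad_output, weights_grad).shape)  # torch.Size([7])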

How can we reproduce the issue?

Quantum Convolution Layer

import torch
import qiskit as qk
from qiskit import QuantumCircuit
import torch.nn as nn
from qiskit_machine_learning.neural_networks import SamplerQNN
# from torch_connector import TorchConnector
from qiskit_machine_learning.connectors import TorchConnector
import torch.nn.functional as F


# Define the quantum circuit

class Quanv2d(nn.Module):
    def __init__(self,input_channel,output_channel,num_qubits,num_weight, kernel_size=3, stride=1):
        super(Quanv2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.qnn = TorchConnector(self.Sampler(num_weight,kernel_size * kernel_size * input_channel, num_qubits))
        #check that 2**num_qubits is at least output_channel
        assert 2**num_qubits >= output_channel, '2**num_qubits must be greater than or equal to output_channel'

    def Sampler(self, num_weights, num_input, num_qubits = 3):
        qc = QuantumCircuit(num_qubits)
        weight_params = [qk.circuit.Parameter('w{}'.format(i)) for i in range(num_weights)]
        input_params = [qk.circuit.Parameter('x{}'.format(i)) for i in range(num_input)]
        #construct the quantum circuit with the parameters
        """
            build quantum circuit here
        """

        #use SamplerQNN to convert the quantum circuit to a PyTorch module
        qnn = SamplerQNN(circuit = qc,weight_params = weight_params,interpret=self.interpret, input_params=input_params,output_shape=self.output_channel)
        return qnn

    def interpret(self, X):
        return X%self.output_channel

    def forward(self, X):
        #for each input channel we have a quantum circuit to process it
        #and then we add them together
        #get the height and width of the output tensor
        height = len(range(0,X.shape[2]-self.kernel_size+1,self.stride))
        width = len(range(0,X.shape[3]-self.kernel_size+1,self.stride))
        output = torch.zeros((X.shape[0],self.output_channel,height,width))
            
        X = F.unfold(X[:, :, :, :], kernel_size=self.kernel_size, stride=self.stride)
        # print(X.shape)
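        # unfold yields (batch, C*kernel_size**2, positions); permuting to
        # (positions, batch, C*kernel_size**2) puts the QNN input features on
        # the last axis. The extra leading positions axis is what makes
        # grad_output 3D (and weights_grad 4D) during backward.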
        qnn_output = self.qnn(X.permute(2, 0, 1)).permute(1, 2, 0)
        qnn_output = torch.reshape(qnn_output,shape=(X.shape[0],self.output_channel,height,width))
        output += qnn_output
        return output

Torch Model

class HybridQNN(nn.Module):
    def __init__(self):
        super(HybridQNN, self).__init__()
        #build a full classical convolutional layer
        self.conv1 = nn.Conv2d(1, 1, 3)
        self.bn1 = nn.BatchNorm2d(1)
        self.sigmoid = nn.Sigmoid()
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv2 = Quanv2d(1, 2, 2, 3,kernel_size=4,stride=3)
        self.bn2 = nn.BatchNorm2d(2)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.sigmoid(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.flatten(x)
        x = self.linear(x)
        return x

Error code

RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given

What should happen?

"ij,ijk->k" should match the dimension of these two parameter(grad_output.detach().cpu() and weights_grad ).

Any suggestions?

Maybe derive the einsum subscripts from the dimensions of the operands, for example:


shape_grad_out = grad_output.detach().cpu().shape
shape_weights_grad = weights_grad.shape
text1 = ''
text2 = ''

# Build one subscript letter per dimension of grad_output ('a', 'b', ...)
for i, dim_size in enumerate(shape_grad_out):
    text1 += chr(97 + i)

# Build one subscript letter per dimension of weights_grad
for i, dim_size in enumerate(shape_weights_grad):
    text2 += chr(97 + i)

# Keep only the trailing weight axis in the output
text3 = text2[-1]

# Use the derived subscripts in the einsum call
weights_grad = torch.einsum(f"{text1},{text2}->{text3}", grad_output.detach().cpu(), weights_grad)
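
For the original 2D/3D operands these loops yield "ab,abc->c" (equivalent to "ij,ijk->k"), and for the 3D/4D case they yield "abc,abcd->d", matching the manual fix above.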

miles0428 · Nov 13 '23

Thanks for pointing this out @miles0428. Perhaps TorchConnector was designed to be coupled to image-type datasets, which would explain the current summation form. This should indeed be matched to the layer calculations. To add to your suggestion, something like

text2 = ''
char_limit = 26  # einsum subscripts only span the letters a-z
for i in range(weights_grad.ndim):
    if i >= char_limit:
        raise RuntimeError('tensor rank exceeds the einsum subscript alphabet')
    text2 += chr(97 + i)  # chr(97) == 'a'

would make the einsum implementation general enough.
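
Putting the two suggestions together, a self-contained sketch might look like this (hypothetical helper name, not the actual TorchConnector implementation; it assumes weights_grad carries exactly one more axis than grad_output, as in both cases above):

import torch

def dynamic_weights_grad(grad_output: torch.Tensor,
                         weights_grad: torch.Tensor) -> torch.Tensor:
    # hypothetical helper, not the library API: contract grad_output with
    # weights_grad over all shared leading axes, keeping the trailing
    # weight axis (assumes weights_grad.ndim == grad_output.ndim + 1)
    if weights_grad.ndim > 26:
        raise RuntimeError('einsum subscripts are limited to the letters a-z')
    lhs = ''.join(chr(97 + i) for i in range(grad_output.ndim))  # e.g. 'abc'
    rhs = lhs + chr(97 + grad_output.ndim)                       # e.g. 'abcd'
    return torch.einsum(f'{lhs},{rhs}->{rhs[-1]}', grad_output, weights_grad)

# the 3D case reduces to the original "ij,ijk->k"; the 4D case to "ijl,ijlk->k"
grad_output = torch.rand(9, 5, 4)       # (positions, batch, outputs)
weights_grad = torch.rand(9, 5, 4, 7)   # (positions, batch, outputs, num_weights)
print(dynamic_weights_grad(grad_output, weights_grad).shape)  # torch.Size([7])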

edoaltamura · Mar 08 '24

@miles0428 could you please provide the exact code you used to instantiate the HybridQNN class? A simple call HybridQNN() returns an error, but not the one you described. Could you also provide a minimal example that reproduces the dim-mismatch error with the 4D case?

edoaltamura · Mar 08 '24

Hi @edoaltamura, here is code with which you can reproduce the error. I filled in the circuit construction that was skipped in the original Quanv2d snippet to make the error reproducible.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, List
import qiskit as qk
from qiskit import QuantumCircuit
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.connectors import TorchConnector

# from torch_connector import TorchConnector


class Quanv2d(nn.Module):

    '''
        A quantum convolutional layer
        args
            input_channel: number of input channels
            output_channel: number of output channels
            num_qubits: number of qubits
            num_weight: number of weights
            kernel_size: size of the kernel
            stride: stride of the kernel
    '''
    def __init__(self,
                 input_channel : int,
                 output_channel : int,
                 num_qubits : int,
                 num_weight : int, 
                 kernel_size : int = 3, 
                 stride : int = 1
                 ):

        super(Quanv2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.num_weight = num_weight
        self.num_input = kernel_size * kernel_size * input_channel
        self.num_qubits = num_qubits
        self.qnn = TorchConnector(self.Sampler())
        assert 2**num_qubits >= output_channel, '2**num_qubits must be greater than or equal to output_channel'

    def build_circuit(self,
                num_weights : int,
                num_input : int,
                num_qubits : int = 3
                ) -> tuple[QuantumCircuit, List[qk.circuit.Parameter], List[qk.circuit.Parameter]]:
        '''
        build the quantum circuit
        param
            num_weights: number of weights
            num_input: number of inputs
            num_qubits: number of qubits
        return
            qc: quantum circuit
            weight_params: weight parameters
            input_params: input parameters
        '''
        qc = QuantumCircuit(num_qubits)
        weight_params = [qk.circuit.Parameter('w{}'.format(i)) for i in range(num_weights)]
        input_params = [qk.circuit.Parameter('x{}'.format(i)) for i in range(num_input)]
        #construct the quantum circuit with the parameters
        for i in range(num_qubits):
            qc.h(i)
        for i in range(num_input):
            qc.ry(input_params[i]*2*torch.pi, i%num_qubits)
        for i in range(num_qubits - 1):
            qc.cx(i, i + 1)
        for i in range(num_weights):
            qc.rx(weight_params[i]*2*torch.pi, i%num_qubits)
        for i in range(num_qubits - 1):
            qc.cx(i, i + 1)
        return qc, weight_params, input_params
    
    def Sampler(self) -> SamplerQNN:
        '''
        wrap the parameterized quantum circuit in a SamplerQNN
        return
            qnn: a SamplerQNN built from the circuit, usable as a
                 PyTorch module via TorchConnector
        '''
        qc, weight_params, input_params = self.build_circuit(self.num_weight, self.num_input, self.num_qubits)
        
        #use SamplerQNN to convert the quantum circuit to a PyTorch module
        qnn = SamplerQNN(
                        circuit = qc,
                        weight_params = weight_params,
                        interpret=self.interpret, 
                        input_params=input_params,
                        output_shape=self.output_channel,
                         )
        return qnn

    def interpret(self, X: Union[List[int],int]) -> Union[int,List[int]]:
        '''
        interpret the output of the quantum circuit using the modulo function
        this function is used in SamplerQNN
        args
            X: output of the quantum circuit
        return
            the remainder of the output divided by the number of output channels
        '''
        return X % self.output_channel

    def forward(self, X : torch.Tensor) -> torch.Tensor:
        '''
        forward function for the quantum convolutional layer
        args
            X: input tensor with shape (batch_size, input_channel, height, width)
        return
            X: output tensor with shape (batch_size, output_channel, height, width)
        '''
        height = len(range(0,X.shape[2]-self.kernel_size+1,self.stride))
        width = len(range(0,X.shape[3]-self.kernel_size+1,self.stride))
        output = torch.zeros((X.shape[0],self.output_channel,height,width))
        X = F.unfold(X[:, :, :, :], kernel_size=self.kernel_size, stride=self.stride)
        qnn_output = self.qnn(X.permute(2, 0, 1)).permute(1, 2, 0)
        qnn_output = torch.reshape(qnn_output,shape=(X.shape[0],self.output_channel,height,width))
        output += qnn_output
        return output
 
if __name__ == '__main__':
    # Define the model
    model = Quanv2d(3, 1, 3, 3, stride=1)

    X = torch.rand((2,3,6,6))
    X.requires_grad = True

    X1 = model.forward(X)
    X1 = torch.sum(X1)
    #do some backward test
    X1.backward()
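
As described above, running this script fails at the X1.backward() call with the einsum RuntimeError quoted in the original report.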

miles0428 · Mar 12 '24

Thanks @miles0428, the last example was very useful. I've allowed the einsum signature to be computed dynamically, which fixes the issue. The fix is currently on a dev branch (see above), and we aim to roll it out to main in the next release at the latest.

edoaltamura · Mar 15 '24