cuda-quantum icon indicating copy to clipboard operation
cuda-quantum copied to clipboard

[RFC] Improve the internal data implementation for cudaq::state

Open amccaskey opened this issue 1 year ago • 3 comments

Background

Currently, cudaq::state just wraps a tuple containing the shape of the state data and a vector holding the state data. This is sub-optimal because it forces copies to be created, especially for state vector data managed by the GPU.

We need to rethink the internal data representation for cudaq::state: remove this tuple and replace it with a new data type specific to simulation backends (i.e. an interface that backend simulators must implement). This allows cudaq::state to track the state data in a backend-native manner, e.g. by holding the GPU device pointer.

Proposal

This RFC introduces the following new types and type changes:

namespace cudaq {

/// @brief `state_data` is a variant type encoding the different forms of
/// user-provided state data that are supported. It holds exactly one of:
///  - `std::vector<cudaq::complex>`
///  - `std::pair<cudaq::complex *, std::size_t>` (presumably a raw data
///    pointer together with its element count - TODO confirm)
///  - `std::vector<cudaq::complex *>` (presumably one data pointer per
///    tensor - TODO confirm)
///  - `std::shared_ptr<TensorNetworkState>`
using state_data = std::variant<
    std::vector<cudaq::complex>,
    std::pair<cudaq::complex *, std::size_t>,
    std::vector<cudaq::complex *>, std::shared_ptr<TensorNetworkState>>;

/// @brief Abstract interface for backend-specific simulation state data.
/// Backend simulators subclass this type so that the state data can be
/// tracked in a backend-native manner. Member functions declared without
/// `= 0` are virtual but not pure; their default implementations are not
/// shown in this declaration.
class SimulationState {
public: 

  /// @brief Simulation states are made up of a
  /// vector of Tensors. Each tensor tracks its data pointer
  /// and the tensor extents.
  struct Tensor {
    /// Pointer to the tensor data; null by default.
    cudaq::complex *data = nullptr;

    /// The extent of each tensor dimension.
    std::vector<std::size_t> extents;
    /// Floating point precision associated with this tensor
    /// (presumably the precision of the elements behind `data` - TODO
    /// confirm).
    precision fp_precision;
    /// @brief Return the tensor rank, i.e. the number of extents.
    std::size_t get_rank() const { return extents.size(); }
    /// @brief Return the number of elements in this tensor (definition not
    /// shown; presumably the product of the extents - TODO confirm).
    std::size_t get_num_elements() const;
    /// @brief Return the size of one tensor element (definition not shown;
    /// presumably in bytes - TODO confirm).
    std::size_t element_size() const ;
  };

  /// @brief Create a new subclass-specific SimulationState
  /// from the user-provided data set.
  /// @param data The user data, in one of the forms `state_data` can hold.
  /// @return A uniquely-owned pointer to the newly created state.
  virtual std::unique_ptr<cudaq::SimulationState>
  createFromData(const state_data &data) ;

  /// @brief Return the tensor at the given index. Throws
  /// for an invalid tensor index.
  /// @param tensorIdx Index of the requested tensor, defaults to 0.
  virtual Tensor getTensor(std::size_t tensorIdx = 0) const = 0;

  /// @brief Return all tensors that represent this state
  virtual std::vector<Tensor> getTensors() const = 0;

  /// @brief Return the number of tensors that represent this state.
  virtual std::size_t getNumTensors() const = 0;

  /// @brief Return the number of qubits this state represents.
  virtual std::size_t getNumQubits() const = 0;

  /// @brief Compute the overlap of this state representation with
  /// the provided `other` state, e.g. `<this | other>`.
  /// @param other The state to compute the overlap with.
  virtual cudaq::complex overlap(const SimulationState &other) = 0;

  /// @brief Return the amplitude of the given computational
  /// basis state.
  /// @param basisState The basis state, given as a vector of integers
  /// (presumably one 0/1 entry per qubit - TODO confirm).
  virtual cudaq::complex
  getAmplitude(const std::vector<int> &basisState) = 0;

  /// @brief Return the floating point precision used by the simulation state.
  virtual precision getPrecision() const = 0;

  /// @brief Destroy the state representation, frees all associated memory.
  virtual void destroyState() = 0;

  /// @brief Return the element from the tensor at the
  /// given tensor index and at the given indices.
  /// @param tensorIdx Index of the tensor to read from.
  /// @param indices Indices of the element within that tensor.
  virtual cudaq::complex 
  operator()(std::size_t tensorIdx, const std::vector<std::size_t> &indices) ;

  /// @brief Return the number of elements in this state representation.
  /// Defaults to adding all shape elements.
  virtual std::size_t getNumElements() const ;

  /// @brief Return true if this `SimulationState` wraps data on the GPU.
  virtual bool isDeviceData() const ;

  /// @brief Return true if this `SimulationState` wraps contiguous memory
  /// (array-like).
  /// If true, `operator()` can be used to index elements in a
  /// multi-dimensional array manner.
  /// Otherwise, `getAmplitude()` must be used.
  virtual bool isArrayLike() const;

  /// @brief Transfer data from device to host, writing the data into the
  /// buffer provided by the client. Clients must specify the number of
  /// elements.
  /// @param clientAllocatedData Client-allocated destination buffer.
  /// @param numElements Number of elements, as specified by the client.
  virtual void toHost(cudaq::complex *clientAllocatedData,
                      std::size_t numElements) const;

  /// @brief Destructor
  virtual ~SimulationState() {}
};

/// @brief Convenience alias for `SimulationState::Tensor`.
using tensor = SimulationState::Tensor; 

/// @brief User-facing handle to simulation state data. All data is held
/// through a backend-specific `SimulationState` instance.
class state {
private:
  /// The backend-specific state representation, uniquely owned.
  std::unique_ptr<SimulationState> internal;

public:
  /// @brief Construct from a raw `SimulationState` pointer. As the parameter
  /// name suggests, the new state presumably takes ownership of the pointer
  /// (constructor definition not shown - TODO confirm).
  state(SimulationState *ptrToOwn); 
  
  /// @brief Convenience function for extracting from a known vector.
  cudaq::complex operator[](std::size_t idx);

  /// @brief Convenience function for extracting from a known matrix.
  cudaq::complex operator()(std::size_t idx, std::size_t jdx);

  /// @brief General extraction operator for state data. Takes the element
  /// indices and the index of the tensor to read from (defaults to 0).
  cudaq::complex operator()(const std::initializer_list<std::size_t> &,
                                  std::size_t tensorIdx = 0);

  /// @brief Return the tensor at the given index for this state representation.
  /// For state-vector and density matrix simulation states, there is just one
  /// tensor with rank 1 or 2 respectively.
  tensor get_tensor(std::size_t tensorIdx = 0) const;

  /// @brief Return all tensors that represent this simulation state.
  std::vector<tensor> get_tensors() const;

  /// @brief Return the number of tensors that represent this state.
  std::size_t get_num_tensors() const;

  /// @brief Return the underlying floating point precision for this state.
  /// NOTE(review): `SimulationState::getPrecision()` returns `precision`
  /// while this returns `cudaq::simulation_precision` - confirm both name
  /// the same type.
  cudaq::simulation_precision get_precision() const;

  /// @brief Return true if this is a state on the GPU.
  bool is_on_gpu() const;

  /// @brief Copy this state from device to host, into the buffer provided
  /// by the caller.
  /// FIXME should be a vector<Tensor> right?
  /// @param hostPtr Destination buffer for the data.
  /// @param numElements Number of elements, as specified by the caller.
  void to_host(cudaq::complex *hostPtr,
               std::size_t numElements) const;

  /// @brief Create a new state from user-provided data.
  /// The data can be host or device data.
  static state from_data(const state_data &data);

  ~state();
};

/// @brief Compute the overlap with the provided `other` state.
/// NOTE(review): as declared at namespace scope this takes only one state;
/// presumably it is meant to be a member of `state` or to take a second
/// state argument - TODO confirm.
cudaq::complex overlap(const state& other); 
/// @brief Return the amplitude of the given computational basis state.
/// NOTE(review): as declared at namespace scope this takes no state to
/// query; presumably it is meant to be a member of `state` - TODO confirm.
cudaq::complex amplitude(const std::vector<int>& basisState); 

}

Concrete CircuitSimulators are then responsible for subclassing SimulationState and returning an instance of their subclass as a pointer.

Examples

    # Example: retrieve the state of a two-qubit kernel (H on qubit 0, then
    # CX from qubit 0 to qubit 1) and compare it with a NumPy array.
    kernel = cudaq.make_kernel()
    qubits = kernel.qalloc(2)
    kernel.h(qubits[0])
    kernel.cx(qubits[0], qubits[1])

    # Get the quantum state, which should be a vector.
    got_state = cudaq.get_state(kernel)

    # Data type needs to be the same as the internal state vector
    # NOTE(review): `dtype` is not defined in this snippet; it presumably
    # comes from the surrounding test setup - TODO confirm.
    want_state = np.array([1. / np.sqrt(2.), 0., 0., 1. / np.sqrt(2.)],
                          dtype=dtype)

    # Check the indexing operators on the State class
    # while also checking their values
    # NOTE(review): the results of these np.isclose calls are discarded;
    # wrap them in `assert` if they are meant to be checks.
    np.isclose(want_state[0], got_state[0].real)
    np.isclose(want_state[1], got_state[1].real)
    np.isclose(want_state[2], got_state[2].real)
    np.isclose(want_state[3], got_state[3].real)

    # Check the entire vector with numpy.
    # The array is requested with copy=False.
    got_vector = np.array(got_state, copy=False)
    for i in range(len(want_state)):
        assert np.isclose(want_state[i], got_vector[i])
        # if not np.isclose(got_vector[i], got_vector_b[i]):
        print(f"want = {want_state[i]}")
        print(f"got = {got_vector[i]}")
    assert np.allclose(want_state, np.array(got_state))

    # Check overlaps.
    # Check the overlap overload with want numpy array.
    assert np.isclose(got_state.overlap(want_state), 1.0)

    # Example: compute the overlap between a simulated state and a CuPy
    # array (`cp` is presumably the `cupy` module - TODO confirm).
    kernel = cudaq.make_kernel()
    q = kernel.qalloc(2)
    kernel.h(q[0])
    kernel.cx(q[0], q[1])

    # State is on the GPU, this is nvidia target
    state = cudaq.get_state(kernel)
    state.dump()
    # Create a state on GPU
    expected = cp.array([.707107, 0, 0, .707107], dtype=cp.complex128)
    # Compute the overlap
    # NOTE(review): `result` is not checked in this snippet.
    result = state.overlap(expected)

    # Example: build a cudaq.State directly from a CuPy array and compute
    # its overlap with a simulated state.
    cp_data = cp.array([.707107, 0, 0, .707107], dtype=cp.complex64)
    state_from_cupy = cudaq.State.from_data(cp_data)
    state_from_cupy.dump()
    kernel = cudaq.make_kernel()
    q = kernel.qalloc(2)
    kernel.h(q[0])
    kernel.cx(q[0], q[1])
    # State is on the GPU, this is nvidia target
    state = cudaq.get_state(kernel)
    result = state.overlap(state_from_cupy)
    # Loose tolerance; presumably because the hard-coded .707107 amplitudes
    # are only approximate and the data is complex64 - TODO confirm.
    assert np.isclose(result, 1.0, atol=1e-3)

   @cudaq.kernel
    def bell():
        qubits = cudaq.qvector(2)
        h(qubits[0])
        x.ctrl(qubits[0], qubits[1])

    # Get the quantum state, which should be a vector.
    got_state = cudaq.get_state(bell)

    want_state = cudaq.State.from_data(np.array([1. / np.sqrt(2.), 0., 0., 1. / np.sqrt(2.)],
                          dtype=np.complex128))
    
    assert len(want_state) == 4 

    # Check the indexing operators on the State class
    # while also checking their values
    np.isclose(want_state[0], got_state[0].real)
    np.isclose(want_state[1], got_state[1].real)
    np.isclose(want_state[2], got_state[2].real)
    np.isclose(want_state[3], got_state[3].real)

    # Check the entire vector with numpy.
    got_vector = np.array(got_state, copy=False)
    for i in range(len(want_state)):
        assert np.isclose(want_state[i], got_vector[i])
        # if not np.isclose(got_vector[i], got_vector_b[i]):
        print(f"want = {want_state[i]}")
        print(f"got = {got_vector[i]}")
    assert np.allclose(want_state, np.array(got_state))
    cudaq.reset_target()

TODO

Before merging to main:

  • [x] SimulationState impls for tensornet, tensornet-mps, mgpu

  • [x] getAmplitude implementations

  • [x] overlap() and amplitude() API functions

  • [ ] Cloud support for overlap() and amplitude() (https://github.com/NVIDIA/cuda-quantum/pull/1653)

  • [ ] from_data for multiple tensors (@1tnguyen)

  • [x] std::vector/numpy array qvector initialization

  • [ ] cudaq::state qvector initialization

    • [x] Library mode
    • [ ] C++ MLIR mode
    • [ ] Python Kernel mode (https://github.com/NVIDIA/cuda-quantum/pull/1706)
    • [ ] Python Builder mode
  • [ ] qvector initialization for hardware backends

Post merging to main

  • [ ] Tests / demonstrate MQPU and copying GPU device data (MPI and single node)
  • [ ] Examples
  • [ ] Documentation

amccaskey avatar Jan 18 '24 16:01 amccaskey