PyMFEM

Performance issues w/ mesh attribute retrieval

Open · tradeqvest opened this issue 1 year ago · 3 comments

Hello,

For my application, I constantly need to retrieve per-element mesh data, e.g. mesh.GetElementTransformation(i) or mesh.GetElementVertices(i). Since this requires a Python loop over every element, performance suffers significantly. Is there a more efficient way that I am overlooking? Can the retrieval be vectorized?

I would appreciate any insights! Thanks in advance for your time and help!

tradeqvest, Mar 21 '24 15:03

As for GetElementVertices, there is Mesh::GetVertexToElementTable, which returns the vertex-to-element mapping as a Table. From the I and J arrays of this table you can build the reverse (element-to-vertex) mapping. Below, I construct a scipy.sparse.csr_matrix from I and J, take its transpose, and convert the result back with tocsr(). The indices and indptr arrays of the resulting matrix then give the element-to-vertex mapping: the vertices of element e are indices[indptr[e]:indptr[e+1]].

```python
import numpy as np
import mfem.ser as mfem
from scipy.sparse import csr_matrix

# In my case the mesh came from a Petra-M model
# (proj.model1.mfem.variables.eval("mesh")); any mfem.Mesh works here.
mesh = mfem.Mesh.MakeCartesian2D(4, 4, mfem.Element.QUADRILATERAL)

tb = mesh.GetVertexToElementTable()
nv, ne = mesh.GetNV(), mesh.GetNE()
nnz = tb.Size_of_connections()
# I has nv + 1 entries: read the first nv and append the total length
i_ptr = mfem.intArray((tb.GetI(), nv)).GetDataArray()
i_ptr = np.hstack((i_ptr, nnz))
j_idx = mfem.intArray((tb.GetJ(), nnz)).GetDataArray()
# rows = vertices, columns = elements; the transpose maps element -> vertices
mat = csr_matrix((np.ones(nnz), j_idx, i_ptr), shape=(nv, ne)).transpose().tocsr()

# well.. let's check if this is correct ;D
for e in range(ne):
    iverts = mat.indices[mat.indptr[e]:mat.indptr[e + 1]]
    iverts2 = mesh.GetElementVertices(e)
    if np.any(np.sort(iverts) != np.sort(iverts2)):
        print("error", e, iverts, iverts2)
```
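For readers without scipy, the same inversion of the vertex-to-element CSR arrays (I, J) can be sketched in plain NumPy. The function name invert_csr is hypothetical, and the example uses a hand-made table for two triangles sharing an edge (element 0 = vertices 0, 1, 2; element 1 = vertices 1, 3, 2) rather than a real mfem.Table.

```python
import numpy as np


def invert_csr(I, J, n_cols):
    """Transpose a CSR connectivity (row -> cols) into (col -> rows).

    I: row pointer array of length n_rows + 1
    J: column indices of length I[-1]
    n_cols: number of columns (here: number of elements)
    Returns (I_t, J_t) in the same CSR layout.
    """
    I = np.asarray(I)
    J = np.asarray(J)
    n_rows = len(I) - 1
    # row index of every stored entry, e.g. [0, 1, 1, 2, 2, 3]
    rows = np.repeat(np.arange(n_rows), np.diff(I))
    # stable sort by column keeps each column's rows in ascending order
    order = np.argsort(J, kind="stable")
    J_t = rows[order]
    # count entries per column, then prefix-sum into a pointer array
    counts = np.bincount(J, minlength=n_cols)
    I_t = np.concatenate(([0], np.cumsum(counts)))
    return I_t, J_t


# vertex -> element table: v0: [0], v1: [0, 1], v2: [0, 1], v3: [1]
I_t, J_t = invert_csr([0, 1, 3, 5, 6], [0, 0, 1, 0, 1, 1], n_cols=2)
for e in range(len(I_t) - 1):
    print(e, J_t[I_t[e]:I_t[e + 1]])
```

As with the scipy version, each element's vertices come out in ascending vertex order, not in the element's local ordering, which is why the check above compares sorted arrays.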

As for `mesh.GetElementTransformation(i)`, I realized that the wrapper calls Tr = IsoparametricTransformation() on every call, so a new object is allocated each time. If that allocation is the issue, we could change the wrapper so that Tr can be passed in as a keyword argument and reused. Otherwise, I am not sure there is a simple way to make this faster.

sshiraiwa, Mar 24 '24 02:03

Thank you very much for your answer! 🙂 The first part worked really well!

Regarding mesh.GetElementTransformation(i), this is the function I would like to speed up:

```python
import numpy as np
from mfem.ser import GridFunction, IntegrationPoint, Vector


def interpolate_solution_at_points(
    fespace, mesh, solution, integration_points, corresponding_elements
):
    """
    Interpolate a finite element solution at given points.

    Args:
    - fespace: The finite element space (mfem.FiniteElementSpace)
    - mesh: The mesh (mfem.Mesh)
    - solution: The finite element solution as DOF values (np.array)
    - integration_points: The points where the solution is to be interpolated (numpy array of shape (n_points, dim))
    - corresponding_elements: For each point, the index of the element that contains it

    Returns:
    - interpolated_values: The interpolated solution values at the given points (numpy array of shape (n_points, 1))
    """
    dim = fespace.GetMesh().Dimension()
    assert (
        integration_points.shape[1] == dim
    ), "Dimension of points must match the mesh dimension"
    grid_function = GridFunction(fespace)
    grid_function.Assign(np.ravel(solution))
    n_points = integration_points.shape[0]
    interpolated_values = np.zeros(n_points)
    ip = IntegrationPoint()
    for i, elem in enumerate(corresponding_elements):
        trans = mesh.GetElementTransformation(elem)
        point = Vector(integration_points[i, :])
        trans.TransformBack(point, ip)  # physical -> reference coordinates
        interpolated_values[i] = grid_function.GetValue(elem, ip)
    return interpolated_values.reshape(-1, 1)
```

If you see a way to make it more efficient, please let me know! 🙂 Thank you in advance for your time and effort!

tradeqvest, Mar 24 '24 15:03

Hi @tradeqvest

What is the size of the problem you are working with?

  • How big is your for loop (how many points are you interpolating over)?
  • How many elements are in your mesh?

The reason I ask is that if Nel << Npoints (far more points than elements), it may be worthwhile, as a first pass, to build the transformations for all elements once and look them up inside the for loop, rather than recreating one for every point.

I ran a quick profile: mesh.GetElementTransformation does take some time, but a large share of the time went into constructing a new Vector on every iteration. You could create a single Vector before the loop and only update its values inside it.

I'm not aware of a vectorized solution (maybe @sshiraiwa knows of one). Could you try these two changes and see whether they improve the speed?
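A sketch of the two suggestions above combined into a drop-in variant of tradeqvest's function (untested against a real mesh, not profiled). Instead of storing a transformation for every element, it sorts the points by element so that GetElementTransformation is called once per distinct element, and it reuses a single mfem.Vector and IntegrationPoint. It assumes that PyMFEM's Vector.GetDataArray() returns a writable NumPy view of the vector's data, and that each GetElementTransformation call returns an independent object (as sshiraiwa described for the wrapper). group_points_by_element is a hypothetical helper. The Vector is sized by mesh.SpaceDimension(), since TransformBack takes a point in physical coordinates.

```python
import numpy as np


def group_points_by_element(elements):
    """Return (unique_elements, groups): groups[k] holds the indices of the
    points whose containing element is unique_elements[k]."""
    elements = np.asarray(elements)
    order = np.argsort(elements, kind="stable")
    uniq, starts = np.unique(elements[order], return_index=True)
    # split the sorted point indices at the start of each element's run
    return uniq, np.split(order, starts[1:])


def interpolate_solution_at_points(
    fespace, mesh, solution, integration_points, corresponding_elements
):
    # Imported here so group_points_by_element can be used without PyMFEM.
    import mfem.ser as mfem

    sdim = mesh.SpaceDimension()
    assert integration_points.shape[1] == sdim, "point dimension mismatch"
    grid_function = mfem.GridFunction(fespace)
    grid_function.Assign(np.ravel(solution))

    values = np.zeros(integration_points.shape[0])
    ip = mfem.IntegrationPoint()  # reused for every point
    point = mfem.Vector(sdim)  # reused for every point
    point_data = point.GetDataArray()  # NumPy view onto point's storage

    uniq, groups = group_points_by_element(corresponding_elements)
    for elem, idx in zip(uniq, groups):
        elem = int(elem)
        # one transformation per distinct element, not one per point
        trans = mesh.GetElementTransformation(elem)
        for k in idx:
            point_data[:] = integration_points[k]  # overwrite in place
            trans.TransformBack(point, ip)  # physical -> reference coords
            values[k] = grid_function.GetValue(elem, ip)
    return values.reshape(-1, 1)
```

How much this helps depends on how many points share an element: with Npoints points spread over Nel distinct elements, the number of GetElementTransformation calls drops from Npoints to at most Nel.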

justinlaughlin, Apr 23 '24 18:04