crystal_map visualization discussion
I've spent some time over the past few weeks working on this with limited success. Raising an issue here in hopes that some other people can weigh in with alternatives or advice.
Short version
`orix.CrystalMap.plot` is fast but restrictive. A mesh representation would be better, but at the cost of some speed. I tried several methods and list the three best below, with example code:
- PatchCollection seems easiest to implement.
- If someone can figure out an irregular-polygon version of 'pcolor', I think that would be faster, but it might not exist.
- Plotly gives really powerful interactive plots, but without some direct JavaScript code, it's prohibitively slow to set up.
The issue:
orix.CrystalMap currently plots EBSD images using matplotlib.pyplot.imshow. This is fast and intuitive, but has some drawbacks:
- it can only plot regular square grids.
- it requires `crystal_map.rotations` to be in the form of a correctly ordered 2D array.
- while not strictly necessary, it becomes more convenient to do flips or edits by flipping and/or transposing the `crystal_map.rotations` variable. This can be problematic if the user is storing a lot of additional property information.
- subsections like single grains or phases cut from the larger EBSD image are usually masks of the original. This can lead to memory overloads if, say, the user wanted 1000 CrystalMap objects in a list, one for each of 1000 grains in the CrystalMap.
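To make the ordering requirement concrete, here is a minimal sketch of what has to happen before `imshow` can be used on scattered per-pixel data: every (x, y) point has to be snapped back into a row/column slot of a 2D array. The helper `scattered_to_image` is hypothetical (not an orix function) and assumes an axis-aligned square grid with known spacing `step`; a hex grid has no such row/column slots, which is exactly the limitation listed above.

```python
import numpy as np


def scattered_to_image(x, y, values, step):
    """Place per-point values into a 2D array suitable for imshow.

    Assumes the points lie on a regular, axis-aligned square grid with
    spacing `step`. Empty slots are left as NaN.
    """
    col = np.rint((x - x.min()) / step).astype(int)
    row = np.rint((y - y.min()) / step).astype(int)
    image = np.full((row.max() + 1, col.max() + 1), np.nan)
    image[row, col] = values
    return image


# Usage: a 3 x 4 grid whose points arrive in shuffled order
rng = np.random.default_rng(0)
yy, xx = np.mgrid[0:3, 0:4] * 0.5
order = rng.permutation(xx.size)
x, y = xx.ravel()[order], yy.ravel()[order]
vals = np.arange(xx.size, dtype=float)[order]
img = scattered_to_image(x, y, vals, step=0.5)
```

The mesh approach discussed below skips this step entirely: each value stays attached to its own centroid, in whatever order it is stored.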
The General solution:
EBSD maps should instead be plotted using some type of mesh tool. This is what MTEX does. For example, run the following MATLAB code and pause execution in `generateUnitCells.m` (GitHub link here):
```matlab
clear all
close all
% load some dummy data
mtexdata alphaBetaTitanium
% plot it
plot(ebsd('Ti (alpha)'), ebsd('Ti (alpha)').orientations)
```
You can see that each "pixel" is really a copy of a polygon based on `ebsd.unitCell`, translated to that pixel's centroid. This has some important advantages:
- it can plot square, rectangular, and hex grids using the exact same method.
- crystal_map rotations can be 1D, 2D, 3D, or nD, and in whatever order the user chooses, as long as there is an equivalent of `ebsd.prop.x` and `ebsd.prop.y` with the correct centroids.
- plots can be easily rotated to weird angles and still look correct (super useful for my work aligning datasets, not sure how relevant it is for everyone else).
- There is a logical extension of this method to plotting irregular polygons, i.e., the same plotting routines can be used for both pixel maps and grain maps.
- There is also a logical extension of this to 3D meshes; see Plotly for some fun interactive examples.
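The unit-cell idea above boils down to one broadcast: an (N, 2) array of centroids plus a (V, 2) array of vertex offsets gives an (N, V, 2) array of polygons, no matter how the N pixels are ordered or what shape the original map had. A minimal NumPy sketch (the function `pixel_polygons` is hypothetical, and this is not MTEX's actual code); rotating the map just means rotating both the centroids and the unit cell:

```python
import numpy as np


def pixel_polygons(centroids, unit_cell, angle=0.0):
    """Return an (N, V, 2) array: one copy of `unit_cell` per centroid.

    centroids: (N, 2) pixel centres, in any order.
    unit_cell: (V, 2) vertex offsets from the centre (the role ebsd.unitCell
    plays in MTEX). `angle` rotates the whole map counterclockwise (radians).
    """
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, -s], [s, c]])
    # Row vectors are rotated by right-multiplying with rot.T; rotating the
    # centroid and the offset separately equals rotating their sum.
    return (centroids @ rot.T)[:, None, :] + (unit_cell @ rot.T)[None, :, :]


# Usage: a unit square cell placed at three centroids, then rotated 90 degrees
square_cell = 0.5 * np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype=float)
centroids = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
polys = pixel_polygons(centroids, square_cell)
rotated = pixel_polygons(centroids, square_cell, angle=np.pi / 2)
```

The resulting (N, V, 2) array can be handed directly to `matplotlib.collections.PolyCollection`; swapping `square_cell` for six hexagon offsets is the only change a hex grid needs.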
All that said, any mesh-based solution will inherently be slower than imshow. Thus, some considerable thought should be put into the best implementation.
Possible Implementation 1: Matplotlib PatchCollection
This seems like the closest parallel to what MTEX already does. I tried several approaches, including JIT compilation with Numba, multithreading, and GPU acceleration, but the fastest method seems to be the one Matplotlib uses on the backend of `hexbin`:
- define a single Polygon object that will be used for each pixel
- pass the polygon into a `PatchCollection` with `offsets=xy_centroids`
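Why one polygon plus offsets is enough: the collection maps the polygon's vertices through `transData`, maps each offset through `AffineDeltaTransform(transData)` (the same linear part, with the translation removed), and adds the two in display space. For an affine `transData` (linear axes), A·v + b + A·o equals A·(v + o) + b, i.e. the vertex translated to the centroid. As far as I can tell this is the trick `Axes.hexbin` relies on; the check below only verifies the arithmetic and assumes linear axis scales.

```python
import numpy as np
from matplotlib.figure import Figure
from matplotlib.transforms import AffineDeltaTransform

fig = Figure()
ax = fig.add_subplot()
ax.set_xlim(0, 10)
ax.set_ylim(0, 5)

vertex = np.array([[0.5, -0.5]])  # polygon vertex, relative to its centre
offset = np.array([[3.0, 2.0]])   # pixel centroid in data coordinates

# What the collection does: vertex through transData, offset through the
# translation-free part of transData, summed in display space.
drawn = (ax.transData.transform(vertex)
         + AffineDeltaTransform(ax.transData).transform(offset))
# What we want: the translated vertex mapped to display space.
expected = ax.transData.transform(vertex + offset)
```

Because only one path is stored, memory and setup cost are dominated by the offsets array rather than by millions of separate polygon objects.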
Using dummy data, this looks a bit like this:
```python
import copy

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import viridis
from matplotlib.collections import PatchCollection
from matplotlib.patches import RegularPolygon
from matplotlib.transforms import AffineDeltaTransform as ADT

# create a grid of centroids
spacing_1d = np.arange(1000, dtype=float)
xx, yy = np.meshgrid(spacing_1d, spacing_1d)
# square grid
sq_points = np.stack([xx.flatten(), yy.flatten()]).T
# hex grid: shift every other row by half a cell, then stretch x so that
# rows are 1 apart and neighbouring centroids in a row are 2/sqrt(3) apart
xx[::2, :] = xx[::2, :] + 0.5
xx = xx * 2 / np.sqrt(3)
hex_points = np.stack([xx.flatten(), yy.flatten()]).T

# define unit polygons
# square: circumradius sqrt(0.5) gives side length 1; pi/4 makes it axis-aligned
single_sq_poly = RegularPolygon([0, 0], 4, radius=np.sqrt(0.5), orientation=np.pi / 4)
# pointy-top hexagon: circumradius 2/3 gives width 2/sqrt(3) and row pitch 1,
# so neighbouring cells on hex_points touch without gaps
single_hex_poly = RegularPolygon([0, 0], 6, radius=2 / 3, orientation=0)
# make deep copies before add_patch(): once a patch is added to an Axes its
# transform includes transData, and PatchCollection bakes that transform in
sq_poly = copy.deepcopy(single_sq_poly)
hex_poly = copy.deepcopy(single_hex_poly)

# plot of single patches by themselves
fig, ax = plt.subplots(2, 2)
ax[0, 0].set_xlim(-2, 2)
ax[0, 0].set_ylim(-2, 2)
ax[0, 0].set_aspect(1)
ax[0, 0].add_patch(single_sq_poly)
ax[1, 0].set_xlim(-2, 2)
ax[1, 0].set_ylim(-2, 2)
ax[1, 0].set_aspect(1)
ax[1, 0].add_patch(single_hex_poly)

# create grids of patches: one path repeated at every centroid (as in hexbin)
sq_grid = PatchCollection([sq_poly], offsets=sq_points,
                          offset_transform=ADT(ax[0, 1].transData))
hex_grid = PatchCollection([hex_poly], offsets=hex_points,
                           offset_transform=ADT(ax[1, 1].transData))
# these can be colored either by setting one scalar value per pixel...
sq_grid.set_array(np.random.rand(len(sq_points)))
# ...or by defining one RGBA color per pixel (colormap input must be in [0, 1])
c = viridis(np.random.rand(len(hex_points)))
hex_grid.set_array(None)
hex_grid.set_color(c)

# plot of patch map (zoomed in to the first ~200 x 200 pixels)
ax[0, 1].set_xlim(-5, 205)
ax[0, 1].set_ylim(-5, 205)
ax[0, 1].set_aspect(1)
ax[0, 1].add_collection(sq_grid)
# ax[0, 1].scatter(sq_points[:, 0], sq_points[:, 1])
ax[1, 1].set_xlim(-5, 235)
ax[1, 1].set_ylim(-5, 205)
ax[1, 1].set_aspect(1)
ax[1, 1].add_collection(hex_grid)
ax[0, 0].set_title('single square patch')
ax[1, 0].set_title('single hex patch')
ax[0, 1].set_title('1000x1000 square intensity grid')
ax[1, 1].set_title('1000x1000 hex rgb grid')
plt.tight_layout()
plt.show()
```
The result looks like this, and plots 2 million polygons in roughly 2-4 seconds (comparable to MTEX).
The actual calculation is nearly instantaneous; the holdup seems to be Matplotlib's rendering. The only real downside is that the plots are sluggish during zooming and panning, and adding tooltips will only slow this down further.
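One possible way to get hover information without per-patch tooltips (not benchmarked at the 2-million-pixel scale here) is Matplotlib's `Axes.format_coord`, which the interactive toolbar calls to build the cursor readout in the status bar. A minimal sketch; `make_format_coord` is a hypothetical helper, and the brute-force nearest-centroid search is an assumption that a KD-tree (for example `scipy.spatial.cKDTree`) would likely beat for large maps:

```python
import numpy as np
from matplotlib.figure import Figure


def make_format_coord(centroids, values):
    """Return a format_coord callback that reports the nearest pixel's value."""
    def format_coord(x, y):
        # squared distance from the cursor to every pixel centroid
        d2 = (centroids[:, 0] - x) ** 2 + (centroids[:, 1] - y) ** 2
        i = int(np.argmin(d2))
        return f"x={x:.2f}, y={y:.2f}, pixel={i}, value={values[i]:.3f}"
    return format_coord


# Usage: attach the callback to an Axes that shows the PatchCollection
centroids = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
values = np.array([0.1, 0.2, 0.3])
fig = Figure()
ax = fig.add_subplot()
ax.format_coord = make_format_coord(centroids, values)
readout = ax.format_coord(0.9, 0.1)
```

The readout string could carry Euler angles, phase, or any other per-pixel property, and nothing extra is attached to the collection itself.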
Possible Implementation 2: Plotting a triangle mesh using Matplotlib tripcolor
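Here is a good demo of tripcolor in action, and here is a snippet of code I wrote using it on the same dataset as above:
```python
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np

# Delaunay triangulation of the square-grid centroids (sq_points from above);
# note the triangles connect pixel centroids, they are not the pixel faces
triang = tri.Triangulation(sq_points[:, 0], sq_points[:, 1])
fig, ax = plt.subplots(1)
# one random value per triangle (flat shading)
ax.tripcolor(triang, facecolors=np.random.rand(triang.triangles.shape[0]))
plt.show()
```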
This seems to take much longer to calculate, but gives a more responsive plot. Unfortunately, this ONLY works with triangle meshes, which means 2x as many faces for square grids and 6x for hex grids. It also means some additional bookkeeping has to be done to track which triangles belong to which pixel.
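A sketch of that bookkeeping for square pixels: instead of letting Delaunay choose triangles between centroids, triangulate the pixel *corners* explicitly, two triangles per pixel, and keep an index array mapping each triangle back to its pixel. The helper `square_pixel_triangulation` is hypothetical and assumes an axis-aligned grid with pixel centres at integer multiples of `step`:

```python
import matplotlib.tri as mtri
import numpy as np
from matplotlib.figure import Figure


def square_pixel_triangulation(nx, ny, step=1.0):
    """Triangulate an nx-by-ny grid of square pixels, two triangles per pixel.

    Returns the Triangulation of the pixel corners and, for each triangle,
    the (row-major) index of the pixel it belongs to.
    """
    # (ny + 1) x (nx + 1) lattice of corners around pixel centres
    cx, cy = np.meshgrid((np.arange(nx + 1) - 0.5) * step,
                         (np.arange(ny + 1) - 0.5) * step)
    ids = np.arange((ny + 1) * (nx + 1)).reshape(ny + 1, nx + 1)
    ll, lr = ids[:-1, :-1].ravel(), ids[:-1, 1:].ravel()  # bottom corners
    ul, ur = ids[1:, :-1].ravel(), ids[1:, 1:].ravel()    # top corners
    triangles = np.vstack([np.column_stack([ll, lr, ur]),
                           np.column_stack([ll, ur, ul])])
    pixel_of_triangle = np.tile(np.arange(nx * ny), 2)
    return mtri.Triangulation(cx.ravel(), cy.ravel(), triangles), pixel_of_triangle


# Usage: colour both triangles of each pixel with that pixel's value
tri, pix = square_pixel_triangulation(4, 3)
pixel_values = np.random.default_rng(0).random(4 * 3)
fig = Figure()
ax = fig.add_subplot()
ax.tripcolor(tri, facecolors=pixel_values[pix])
```

Indexing `pixel_values[pix]` is the whole bookkeeping step; a hex grid would need its own corner lattice and more triangles per pixel.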
There is also a quad mesh (`pcolormesh`), which seems to plot faster, but of course cannot do hexagonal plots. Also, these methods do not translate nicely to the concept of grain maps. [Example of a quad plot here](https://matplotlib.org/stable/gallery/images_contours_and_fields/quadmesh_demo.html#sphx-glr-gallery-images-contours-and-fields-quadmesh-demo-py).
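For the quad-mesh route, `pcolormesh` takes 2D arrays of pixel *corners* (one larger than the value array in each direction with `shading="flat"`), so the quads do not have to be axis-aligned. A minimal sketch of a square grid rotated by an arbitrary angle; the 30-degree rotation and random values are just for illustration:

```python
import numpy as np
from matplotlib.figure import Figure

ny, nx = 3, 4
values = np.random.default_rng(0).random((ny, nx))  # one value per pixel

# Pixel corners, shape (ny + 1, nx + 1); pixel centres sit at integer coordinates
cx, cy = np.meshgrid(np.arange(nx + 1) - 0.5, np.arange(ny + 1) - 0.5)

# Rotate every corner by 30 degrees about the origin
theta = np.deg2rad(30)
rx = cx * np.cos(theta) - cy * np.sin(theta)
ry = cx * np.sin(theta) + cy * np.cos(theta)

fig = Figure()
ax = fig.add_subplot()
mesh = ax.pcolormesh(rx, ry, values, shading="flat")
ax.set_aspect("equal")
```

This covers square and rectangular grids (including rotated ones), but each cell is still a quadrilateral, which is why hex grids and grain outlines fall outside it.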
This FEELS like there should be a better way to make similar pcolor plots with irregular polygons, but I cannot figure it out. If someone else can, feel free to give it a go.
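For reference, Matplotlib's `PolyCollection` already accepts a list of polygons with *different* vertex counts plus one scalar per polygon mapped through a colormap, which is close to an irregular-polygon `pcolor`. Its speed at the scales discussed above is not measured here. A minimal sketch with three made-up "grains":

```python
import numpy as np
from matplotlib.collections import PolyCollection
from matplotlib.figure import Figure

# Hypothetical grain outlines with 4, 5 and 6 vertices
grains = [
    np.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float),
    np.array([[2, 0], [3, 0], [3.5, 1], [2.5, 2], [2, 1]], dtype=float),
    np.array([[0, 1], [2, 1], [2.5, 2], [2, 3], [0.5, 3], [0, 2]], dtype=float),
]
grain_values = np.array([0.2, 0.5, 0.9])

fig = Figure()
ax = fig.add_subplot()
coll = PolyCollection(grains, cmap="viridis", edgecolors="k", linewidths=0.5)
coll.set_array(grain_values)  # one scalar per polygon, like pcolor's C
ax.add_collection(coll)
ax.autoscale_view()
ax.set_aspect("equal")
```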
Possible Implementation 3: Plotly
For those who have not seen it, you can do some incredibly cool interactive plots with Plotly. In particular, an EBSD scan is essentially what GIS people call a choropleth, and [Plotly has some insanely fast tools for rendering them](https://plotly.com/python/mapbox-county-choropleth/).
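The data a choropleth expects is a GeoJSON `FeatureCollection` with one polygon per region. A sketch of converting pixel polygons (an (N, V, 2) array) into that structure; `polygons_to_geojson` is a hypothetical helper. The result could then be passed as `geojson=` to a Plotly choropleth trace together with `locations=` (pixel ids) and `z=` (values), although EBSD coordinates would presumably need rescaling into a valid longitude/latitude range first. Note this loop is itself pure Python, so it inherits the speed concern discussed below.

```python
import numpy as np


def polygons_to_geojson(polygons):
    """Build a GeoJSON FeatureCollection with one Polygon feature per pixel.

    GeoJSON rings must be closed, so the first vertex is repeated at the end.
    Feature ids are the pixel indices.
    """
    features = []
    for i, poly in enumerate(polygons):
        ring = np.vstack([poly, poly[:1]]).tolist()
        features.append({
            "type": "Feature",
            "id": i,
            "properties": {},
            "geometry": {"type": "Polygon", "coordinates": [ring]},
        })
    return {"type": "FeatureCollection", "features": features}


# Usage: two neighbouring unit-square pixels
square = 0.5 * np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype=float)
polys = np.stack([square, square + [1.0, 0.0]])
gj = polygons_to_geojson(polys)
```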
However, this has two downsides.
- Plotly creates .html objects, which work well with Jupyter, not so well with Spyder or VS Code.
- Plotly.py is, at its core, a bunch of code that writes JSON, which is then read by Plotly.js.
The first seems acceptable, but the second has been a problem. I can make super responsive 2000x2000 pixel EBSD plots and even 500x500x500 3D meshes, but writing them takes 5-10 minutes, because the Python code writes the data out as ASCII text. The fastest method I found was to build `Mesh3d` objects like this:
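To get a feel for the text overhead, compare one coordinate array serialized as a JSON list with the same numbers stored as raw 4-byte floats. This only estimates size, not writing time, and uses random numbers as a stand-in for real coordinates:

```python
import json

import numpy as np

n = 400                                           # 400 x 400 grid, one float per vertex
coords = np.random.default_rng(0).random(n * n)

as_text = json.dumps(coords.tolist()).encode()    # how a JSON figure spec stores it
as_binary = coords.astype(np.float32).tobytes()   # raw 4-byte floats

print(len(as_text) / len(as_binary))
```

Every vertex coordinate, face index and color goes through this text path, so the cost grows with the number of pixels times the number of arrays per trace.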
```python
import numpy as np
import plotly.graph_objects as go
from matplotlib import cm
from plotly.offline import plot

width = 400
xx, yy = np.meshgrid(np.arange(width), np.arange(width))
x = xx.flatten()
y = yy.flatten()

# vertex index of the lower-left corner of every pixel (last row/column excluded)
ids = np.arange(width * width).reshape(width, width)
corner = ids[:-1, :-1].flatten()
# two triangles per square pixel
f1 = np.stack([corner, corner + 1, corner + width], axis=1)
f2 = np.stack([corner + width + 1, corner + 1, corner + width], axis=1)
faces = np.vstack([f1, f2])

# one color per pixel, repeated for both of its triangles; facecolor needs
# one entry per face, given here as 'rgb(r,g,b)' strings
pixel_vals = (x[corner] * np.pi) % 1
rgb = (cm.viridis(np.tile(pixel_vals, 2))[:, :3] * 255).astype(int)
facecolor = [f"rgb({r},{g},{b})" for r, g, b in rgb]

m = go.Mesh3d(
    x=x,
    y=y,
    z=np.zeros_like(x),
    i=faces[:, 0],
    j=faces[:, 1],
    k=faces[:, 2],
    facecolor=facecolor,
    showscale=False,
)
fig1 = go.Figure(data=m)
plot(fig1, auto_open=True)
```
Plots look like this. You get tooltips for free, and can encode whatever information you want in them (RGB, Euler angles, quaternion, axis-angle, etc.).
Of course, this again brings up the problem of representing single patches as multiple triangles.
I think there may be a more generic way to do this using `plotly.graph_objects.Scatter`, where each pixel is a different Scatter object. However, this shouldn't be done in Python, because writing out the necessary JSON is prohibitively slow. THAT SAID, if someone knowledgeable in JS is interested, this seems like a much better plotting tool for what we want.
Here is a very rough example for creating a 50x50 grid:
```python
import numpy as np
import plotly.graph_objects as go
from plotly.offline import plot

n = 50
spacing_1d = np.arange(n, dtype=float)
xx, yy = np.meshgrid(spacing_1d, spacing_1d)
x = xx.flatten()
y = yy.flatten()
ids = np.arange(n * n).reshape(n, n)
# closed outline of every square pixel: its 4 corner indices, first one repeated
con = np.stack(
    [ids[:-1, :-1].flatten(), ids[:-1, 1:].flatten(),
     ids[1:, 1:].flatten(), ids[1:, :-1].flatten(),
     ids[:-1, :-1].flatten()]).T
data = []
for i in range(len(con)):
    # random fill color per pixel
    c = '#%02x%02x%02x' % (np.random.randint(256), np.random.randint(256), np.random.randint(256))
    name = "I can put custom data here: {}".format(i)
    customdata = np.random.rand(3)
    data.append(
        go.Scatter(
            x=x[con[i]],
            y=y[con[i]],
            mode='lines',
            line=dict(color='rgb(150, 150, 150)', width=2),
            fillcolor=c,
            fill='toself',
            customdata=customdata,
            name=name,
        )
    )
fig = go.Figure(data)
fig.update_layout(width=1000, height=1000, showlegend=False, template="none")
plot(fig, auto_open=True)
```
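A possible way to cut down the number of traces (not tested at EBSD scale here): Plotly's `fill='toself'` fills each gap-separated segment of a trace as its own closed shape, so many polygons can share one trace if their coordinates are separated by `None`. Since a trace has a single `fillcolor`, this would mean one trace per color group (for example per grain or per phase) rather than per pixel. A sketch of the coordinate flattening; `polygons_to_segments` is a hypothetical helper:

```python
import numpy as np


def polygons_to_segments(polygons):
    """Flatten (N, V, 2) polygons into x, y lists separated by None.

    Each polygon is closed (first vertex repeated) and followed by a None gap,
    so one Scatter trace can hold all polygons of the same color.
    """
    xs, ys = [], []
    for poly in polygons:
        xs.extend(poly[:, 0].tolist() + [poly[0, 0], None])
        ys.extend(poly[:, 1].tolist() + [poly[0, 1], None])
    return xs, ys


# Usage: two unit-square pixels in one trace
square = 0.5 * np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype=float)
polys = np.stack([square, square + [1.0, 0.0]])
xs, ys = polygons_to_segments(polys)
```

The lists would then go into a single `go.Scatter(x=xs, y=ys, mode='lines', fill='toself', fillcolor=...)`, which still writes JSON but creates far fewer trace objects.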
If anyone has any input, issues, feedback, etc, let me know.