pytorch3d
Add Unity camera conversion
🚀 Feature
- Construct a camera via cameras_from_unity_world_to_camera_matrix(unity_world_to_camera_matrix)
- Construct R, T via convert_unity_transform(unity_camera_position, unity_camera_rotation)
Motivation
As Unity is widely used in game development (e.g., VR/AR/MR), it would be useful to add Unity camera conversion.
I searched the issues and found a similar one, #708.
Official support would open up many more use cases.
Pitch
I have a few years of experience with Unity, but am new to PyTorch3D.
cameras_from_unity_world_to_camera_matrix(unity_world_to_camera_matrix)
I dug into it a bit and found a way to construct a Transform3d object from a Unity world-to-camera matrix. This Transform3d object can be used as the world-to-view transform.
Here is an example.
import numpy as np
import torch
from pytorch3d.transforms import Transform3d


def transform_unity_camera_view_matrix(matrix, device):
    """
    This function takes the world-to-camera matrix of a Unity camera
    and outputs the equivalent transform in PyTorch3D convention.
    The view matrix is also called the world-to-camera matrix.
    It can be obtained via `camera.worldToCameraMatrix` in Unity.

    Args:
        matrix: A numpy array of shape (4, 4).
        device: The device for storing the implemented transformation.

    .. code-block:: python

        world_to_view_matrix = '''
        0.50060 0.82827 0.25173 2.21195
        -0.86366 0.49767 0.08003 -0.99088
        0.05899 0.25747 -0.96448 -15.64826
        0.00000 0.00000 0.00000 1.00000
        '''
        w2v = np.fromstring(world_to_view_matrix, count=16, sep=' ').reshape((4, 4))
        w2v = transform_unity_camera_view_matrix(w2v, 'cpu')
        print(w2v.get_matrix())
        tensor([[[ 0.5006,  0.8637,  0.0590,  0.0000],
                 [-0.8283,  0.4977, -0.2575,  0.0000],
                 [-0.2517,  0.0800,  0.9645,  0.0000],
                 [-2.2120, -0.9909, 15.6483,  1.0000]]])

    Returns:
        Transform3d object.
    """
    # Sign mask that converts between Unity's and PyTorch3D's axis conventions
    transform_tensor = torch.tensor([
        [1, -1, 1, 1],
        [-1, 1, -1, 1],
        [-1, 1, -1, 1],
        [-1, 1, -1, 1]],
        dtype=torch.float32, device=device)
    # Transpose to PyTorch3D's row-vector convention, then apply the sign mask
    matrix = torch.as_tensor(matrix.T, dtype=torch.float32, device=device)
    matrix = matrix * transform_tensor
    return Transform3d(matrix=matrix, device=device)
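The element-wise sign mask can be read as a pair of axis flips. Working through the multiplication by hand (my own derivation, not part of the original), the mask gives `diag(-1, 1, 1) @ M[:3, :3].T @ diag(-1, 1, -1)` for the rotation block and `t @ diag(-1, 1, -1)` for the translation row. In other words, the world X axis is flipped (Unity's left-handed world to a right-handed one), and the camera X and Z axes are flipped (Unity's view space, which looks down -Z, to PyTorch3D's +X left, +Z forward camera). The numpy sketch below checks this against the docstring example. `unity_view_to_row_matrix` is a hypothetical name.

```python
import numpy as np

# Unity camera.worldToCameraMatrix from the docstring example (column-vector convention)
unity_w2c = np.array([
    [0.50060, 0.82827, 0.25173, 2.21195],
    [-0.86366, 0.49767, 0.08003, -0.99088],
    [0.05899, 0.25747, -0.96448, -15.64826],
    [0.0, 0.0, 0.0, 1.0],
])

# Sign mask used by transform_unity_camera_view_matrix
SIGN_MASK = np.array([
    [1, -1, 1, 1],
    [-1, 1, -1, 1],
    [-1, 1, -1, 1],
    [-1, 1, -1, 1],
], dtype=np.float64)

FLIP_WORLD = np.diag([-1.0, 1.0, 1.0])  # flip world X (left-handed -> right-handed)
FLIP_VIEW = np.diag([-1.0, 1.0, -1.0])  # flip camera X and Z (-Z forward -> +Z forward)


def unity_view_to_row_matrix(m):
    """Hypothetical helper: the sign-mask conversion written as explicit axis flips."""
    rot, t = m[:3, :3], m[:3, 3]
    out = np.eye(4)
    out[:3, :3] = FLIP_WORLD @ rot.T @ FLIP_VIEW  # rotation block in row-vector layout
    out[3, :3] = t @ FLIP_VIEW                    # translation goes in the bottom row
    return out


masked = unity_w2c.T * SIGN_MASK
flipped = unity_view_to_row_matrix(unity_w2c)
print(np.allclose(masked, flipped))
print(np.round(flipped, 4))
```

Writing the flips out explicitly makes it easier to see which convention each sign comes from, and it would make a mistake in the mask easier to spot.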
To use it with a camera, override get_world_to_view_transform:
from pytorch3d.renderer import FoVPerspectiveCameras

w2v = transform_unity_camera_view_matrix(w2v, device)

class MyFoVPerspectiveCameras(FoVPerspectiveCameras):
    def get_world_to_view_transform(self, **kwargs):
        return w2v
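Subclassing works, but it ties the camera to a global `w2v`. An alternative, under the assumption that PyTorch3D stores the world-to-view transform as `[[R, 0], [T, 1]]` in its row-vector convention, is to split the converted matrix into the `R` and `T` that the camera constructor already accepts. This is the same layout that convert_unity_transform reads R and T from. The numpy sketch below uses a hypothetical helper, `split_view_matrix`, and checks that `x @ R + T` matches the homogeneous product.

```python
import numpy as np


def split_view_matrix(m):
    """Hypothetical helper: split a row-vector [[R, 0], [T, 1]] matrix into R and T.

    Returns R with shape (1, 3, 3) and T with shape (1, 3), the batched
    shapes PyTorch3D cameras take.
    """
    m = np.asarray(m, dtype=np.float64)
    R = m[:3, :3][None, ...]
    T = m[3, :3][None, ...]
    return R, T


# Converted matrix from the transform_unity_camera_view_matrix docstring example
w2v = np.array([
    [0.5006, 0.8637, 0.0590, 0.0],
    [-0.8283, 0.4977, -0.2575, 0.0],
    [-0.2517, 0.0800, 0.9645, 0.0],
    [-2.2120, -0.9909, 15.6483, 1.0],
])
R, T = split_view_matrix(w2v)

# Row-vector convention: x_view = x_world @ R + T equals the homogeneous product
x_world = np.array([1.0, 2.0, 3.0])
x_view_rt = x_world @ R[0] + T[0]
x_view_h = (np.append(x_world, 1.0) @ w2v)[:3]
print(np.allclose(x_view_rt, x_view_h))
```

With PyTorch3D installed, the two arrays could then be passed as `FoVPerspectiveCameras(R=torch.tensor(R, dtype=torch.float32), T=torch.tensor(T, dtype=torch.float32), device=device)`, so no subclass is needed.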
convert_unity_transform(unity_camera_position, unity_camera_rotation)
I spent some time trying to implement this, but could not get it to work. One reason is that the computed rotation matrix differs from Unity's.
Here is the testing code.
- Implementation of euler_to_rotation_matrix
import numpy as np


def euler_to_rotation_matrix(angles):
    angles = np.radians(angles)
    # angles are (roll, pitch, yaw) about (x, y, z);
    # the result is Rz(yaw) @ Ry(pitch) @ Rx(roll), i.e. ZYX order
    roll, pitch, yaw = angles
    rotation_matrix = np.array([
        [np.cos(yaw) * np.cos(pitch),
         np.cos(yaw) * np.sin(pitch) * np.sin(roll) - np.sin(yaw) * np.cos(roll),
         np.cos(yaw) * np.sin(pitch) * np.cos(roll) + np.sin(yaw) * np.sin(roll)],
        [np.sin(yaw) * np.cos(pitch),
         np.sin(yaw) * np.sin(pitch) * np.sin(roll) + np.cos(yaw) * np.cos(roll),
         np.sin(yaw) * np.sin(pitch) * np.cos(roll) - np.cos(yaw) * np.sin(roll)],
        [-np.sin(pitch), np.cos(pitch) * np.sin(roll), np.cos(pitch) * np.cos(roll)]
    ])
    return rotation_matrix
- Testing using different packages/methods
"""
Unity
print(Matrix4x4.TRS(Vector3.zero, Quaternion.Euler(10.21f, -6.03f, 50.19f), Vector3.one))
0.62240 -0.77584 -0.10339 0.00000
0.75601 0.63011 -0.17726 0.00000
0.20267 0.03216 0.97872 0.00000
0.00000 0.00000 0.00000 1.00000
"""
rotation = np.array([10.21, -6.03, 50.19])
print('scipy.spatial.transform.Rotation')
print(Rotation.from_euler('xyz', rotation, degrees=True).as_matrix())
print('pyrr.Matrix44')
print(Matrix44.from_eulers(np.radians(rotation), dtype=float))
print('pytorch3d.transforms.euler_angles_to_matrix')
print(euler_angles_to_matrix(torch.tensor(np.radians(rotation)), "XYZ"))
print('euler_to_rotation_matrix')
print(euler_to_rotation_matrix(rotation))
scipy.spatial.transform.Rotation
[[ 0.63670133 -0.76792931  0.06997141]
 [ 0.76392152  0.61580146 -0.19290535]
 [ 0.10504918  0.17627576  0.97871933]]
pyrr.Matrix44
[[ 0.63670133  0.2023555   0.74408579  0.        ]
 [-0.10504918  0.97871933 -0.17627576  0.        ]
 [-0.76392152  0.03406941  0.64440918  0.        ]
 [ 0.          0.          0.          1.        ]]
pytorch3d.transforms.euler_angles_to_matrix
tensor([[ 0.6367, -0.7639, -0.1050],
        [ 0.7441,  0.6444, -0.1763],
        [ 0.2024,  0.0341,  0.9787]], dtype=torch.float64)
euler_to_rotation_matrix
[[ 0.63670133 -0.76792931  0.06997141]
 [ 0.76392152  0.61580146 -0.19290535]
 [ 0.10504918  0.17627576  0.97871933]]
Conclusion
- None of them matches Unity's output.
- scipy.spatial.transform.Rotation == euler_to_rotation_matrix (customized implementation)
- pytorch3d.transforms.euler_angles_to_matrix != scipy.spatial.transform.Rotation
- pytorch3d.transforms.euler_angles_to_matrix != pyrr.Matrix44
Since none of these rotation matrices matches Unity's, it is hard to derive working R and T from the Unity camera's position and rotation.
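The mismatches above mostly come down to multiplication order. As far as I understand the conventions, scipy's lowercase `'xyz'` means extrinsic rotations, which gives Rz @ Ry @ Rx. That is the same product euler_to_rotation_matrix builds. PyTorch3D's `euler_angles_to_matrix(..., "XYZ")` instead multiplies the per-axis matrices left to right, which gives Rx @ Ry @ Rz. The numpy sketch below reproduces both printed results from the same three elementary rotations. `rot_x`, `rot_y` and `rot_z` are hypothetical helpers. pyrr is left out because I have not checked its convention.

```python
import numpy as np


def rot_x(deg):
    """Right-handed rotation about X (column-vector convention)."""
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])


def rot_y(deg):
    """Right-handed rotation about Y (column-vector convention)."""
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])


def rot_z(deg):
    """Right-handed rotation about Z (column-vector convention)."""
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])


x, y, z = 10.21, -6.03, 50.19
scipy_like = rot_z(z) @ rot_y(y) @ rot_x(x)     # extrinsic 'xyz'
pytorch3d_like = rot_x(x) @ rot_y(y) @ rot_z(z)  # "XYZ", multiplied left to right

# Values printed by the test script above
scipy_printed = np.array([
    [0.63670133, -0.76792931, 0.06997141],
    [0.76392152, 0.61580146, -0.19290535],
    [0.10504918, 0.17627576, 0.97871933],
])
pytorch3d_printed = np.array([
    [0.6367, -0.7639, -0.1050],
    [0.7441, 0.6444, -0.1763],
    [0.2024, 0.0341, 0.9787],
])
print(np.allclose(scipy_like, scipy_printed, atol=1e-5))
print(np.allclose(pytorch3d_like, pytorch3d_printed, atol=1e-3))
```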
Thanks for your consideration!
After some more experiments, I figured out how to get the correct R and T from the Unity camera's position and rotation.
The rotation matrix was different because of the multiplication order.
In Unity, the rotations are multiplied in yxz order, i.e. `ry @ rx @ rz` in numpy syntax.
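A quick numpy check of this order follows. It assumes Unity's Euler angles map to the standard right-handed rotation matrices, and with that assumption it reproduces the Unity printout from the earlier test. `rot_x`, `rot_y`, `rot_z` and `unity_euler_to_matrix` are hypothetical helper names.

```python
import numpy as np


def rot_x(deg):
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])


def rot_y(deg):
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])


def rot_z(deg):
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])


def unity_euler_to_matrix(euler_deg):
    """Hypothetical helper: rotation matrix for Quaternion.Euler(x, y, z).

    Column-vector convention, composed as ry @ rx @ rz, so Z is applied
    first, then X, then Y.
    """
    x, y, z = euler_deg
    return rot_y(y) @ rot_x(x) @ rot_z(z)


# Upper-left 3x3 of Matrix4x4.TRS(Vector3.zero, Quaternion.Euler(10.21f, -6.03f, 50.19f), Vector3.one)
unity_output = np.array([
    [0.62240, -0.77584, -0.10339],
    [0.75601, 0.63011, -0.17726],
    [0.20267, 0.03216, 0.97872],
])
print(np.allclose(unity_euler_to_matrix([10.21, -6.03, 50.19]), unity_output, atol=1e-3))
```

If I read PyTorch3D's convention correctly (it multiplies the per-axis matrices in the order the convention string lists them), the same matrix should also come from `euler_angles_to_matrix(torch.deg2rad(torch.tensor([y, x, z])), "YXZ")`. That call could replace the three RotateAxisAngle calls, but I have not run it.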
- Here is the finished version of convert_unity_transform(translation, rotation, device).
import numpy as np
import torch
from pytorch3d.transforms import RotateAxisAngle


def convert_unity_transform(translation, rotation, device):
    """
    This function takes a Unity camera's translation and rotation
    and outputs R and T in PyTorch3D convention.
    The `translation` and `rotation` can be read from the Unity Inspector.

    Args:
        translation: A list of size 3.
        rotation: A list of size 3, in xyz order and in degrees.
        device: The device for storing the implemented transformation.

    .. code-block:: python

        camera_translation = [1, 1.4, -14.71]
        camera_rotation = [10.21, -6.03, 50.19]
        R, T = convert_unity_transform(camera_translation, camera_rotation, 'cpu')
        print(R)
        print(T)
        tensor([[[ 0.6224,  0.7758,  0.1034],
                 [-0.7560,  0.6301, -0.1773],
                 [-0.2027,  0.0322,  0.9787]]])
        tensor([[-1.3004,  0.3668, 14.7485]])

    Returns:
        R and T.
    """
    # Step 1: construct the view (camera) to world transform
    t = torch.eye(4, device=device)
    t[:3, 3] = torch.tensor(translation, dtype=torch.float32, device=device)
    # calculate each axis rotation separately to replicate Unity's rotation matrix
    rx = RotateAxisAngle(rotation[0], 'X', degrees=True, device=device)
    ry = RotateAxisAngle(rotation[1], 'Y', degrees=True, device=device)
    rz = RotateAxisAngle(rotation[2], 'Z', degrees=True, device=device)
    # this order matters:
    # `ry @ rx @ rz` in numpy syntax
    r = rz.compose(rx).compose(ry)
    r = r.get_matrix()
    # flip Z to match camera.worldToCameraMatrix, whose view space looks down -Z
    s = torch.eye(4, device=device)
    s[2, 2] = -1
    v2w = t @ r[0, ...].T @ s
    # Step 2: construct the world to view transform
    w2v = torch.inverse(v2w)
    # Step 3: change to PyTorch3D convention
    # This step can be replaced by transform_unity_camera_view_matrix(w2v, device).
    transform_tensor = torch.tensor([
        [1, -1, 1, 1],
        [-1, 1, -1, 1],
        [-1, 1, -1, 1],
        [-1, 1, -1, 1]],
        dtype=torch.float32, device=device)
    matrix = w2v.T * transform_tensor
    # Step 4: output the result
    R = matrix[:3, :3][None, ...]
    T = matrix[3, :3][None, ...]
    return R, T
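Composing Steps 1 to 3 by hand suggests that the whole conversion collapses to a closed form. This is my own simplification, not part of the original, so it is worth double-checking. The closed form is R = F @ R_unity @ F and T = -C @ R. Here F = diag(-1, 1, 1) flips the world X axis, and C is the Unity camera position with its X component negated. T = -C @ R is the usual camera-centre relation in PyTorch3D's row-vector convention: the camera centre must map to the view-space origin. The numpy sketch below reproduces the docstring output. `unity_camera_to_rt` and `_rot` are hypothetical names.

```python
import numpy as np


def _rot(axis, deg):
    """Right-handed rotation about a single axis (column-vector convention)."""
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    if axis == "x":
        return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
    if axis == "y":
        return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])


FLIP_X = np.diag([-1.0, 1.0, 1.0])


def unity_camera_to_rt(position, euler_deg):
    """Hypothetical closed form of convert_unity_transform (numpy, no batch dim).

    R and T follow x_view = x_world @ R + T, in a world whose X axis is
    flipped relative to Unity's.
    """
    x, y, z = euler_deg
    r_unity = _rot("y", y) @ _rot("x", x) @ _rot("z", z)  # Unity's yxz order
    R = FLIP_X @ r_unity @ FLIP_X
    center = np.asarray(position, dtype=np.float64) @ FLIP_X  # camera centre, X negated
    T = -center @ R  # camera centre maps to the view-space origin
    return R, T


R, T = unity_camera_to_rt([1, 1.4, -14.71], [10.21, -6.03, 50.19])
print(np.round(R, 4))
print(np.round(T, 4))
```

Both printed values can be compared with the docstring example of convert_unity_transform. Avoiding the explicit 4x4 inverse may also be slightly more robust numerically.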
- I also implemented a function that converts a Unity GameObject's transform, which can be used to transform a mesh.
import numpy as np
import torch
from pytorch3d.transforms import RotateAxisAngle, Transform3d


def transform_unity_object_transform(translation, rotation, device):
    """
    This function takes an object's world translation and rotation in Unity
    and outputs a Transform3d in PyTorch3D convention.
    The `translation` and `rotation` can be read from the Unity Inspector.

    Args:
        translation: A list of size 3.
        rotation: A list of size 3, in xyz order and in degrees.
        device: The device for storing the implemented transformation.

    .. code-block:: python

        t = [3, 4, 7]
        r = [-30, 10, 70]
        res = transform_unity_object_transform(t, r, 'cpu')
        print(res.get_matrix().numpy())
        [[[ 0.255236   -0.8137977   0.52209944  0.        ]
          [ 0.95511216  0.29619804 -0.00523604  0.        ]
          [-0.15038374  0.5         0.8528685   0.        ]
          [-3.          4.          7.          1.        ]]]

    Returns:
        Transform3d object.
    """
    x, y, z = rotation
    # calculate each axis rotation separately to replicate Unity's rotation matrix
    rx = RotateAxisAngle(x, 'X', degrees=True, device=device)
    ry = RotateAxisAngle(y, 'Y', degrees=True, device=device)
    rz = RotateAxisAngle(z, 'Z', degrees=True, device=device)
    # this order matters:
    # `ry @ rx @ rz` in numpy syntax
    r = rz.compose(rx).compose(ry)
    matrix = r.get_matrix()
    # the translation goes in the bottom row (row-vector convention)
    matrix[..., 3, :3] = torch.tensor(translation, dtype=torch.float32, device=device)
    transform_tensor = torch.tensor([
        [1, -1, -1, 1],
        [-1, 1, 1, 1],
        [-1, 1, 1, 1],
        [-1, 1, 1, 1]],
        dtype=torch.float32, device=device)
    matrix = matrix * transform_tensor
    return Transform3d(matrix=matrix, device=device)
from pytorch3d.structures import Meshes


def translate_mesh(mesh, t, r, device):
    # despite the name, this applies both the rotation and the translation
    matrix = transform_unity_object_transform(t, r, device)
    tverts = matrix.transform_points(mesh.verts_list()[0].to(device))
    tmesh = Meshes(
        verts=[tverts],
        faces=[mesh.faces_list()[0].to(device)],
        textures=mesh.textures
    )
    return tmesh
Usage:
mesh_t = [3, 4, 7]
mesh_r = [-30, 10, 70]
mesh = translate_mesh(mesh, mesh_t, mesh_r, device=device)
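For reference, here is a numpy-only sketch of the same object conversion, which can help check values without PyTorch3D. As in transform_unity_object_transform, it flips the X axis on both sides of the transposed Unity rotation and negates the X component of the translation. Input vertices are therefore assumed to be in the same X-flipped frame. `unity_object_to_row_matrix`, `transform_vertices` and `_rot` are hypothetical names. The printed matrix can be compared with the docstring example.

```python
import numpy as np


def _rot(axis, deg):
    """Right-handed rotation about a single axis (column-vector convention)."""
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    if axis == "x":
        return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
    if axis == "y":
        return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])


FLIP_X = np.diag([-1.0, 1.0, 1.0])


def unity_object_to_row_matrix(translation, euler_deg):
    """Hypothetical numpy equivalent of transform_unity_object_transform (row-vector 4x4)."""
    x, y, z = euler_deg
    r_unity = _rot("y", y) @ _rot("x", x) @ _rot("z", z)  # Unity's yxz order
    m = np.eye(4)
    m[:3, :3] = FLIP_X @ r_unity.T @ FLIP_X  # transpose for row vectors, flip X on both sides
    m[3, :3] = np.asarray(translation, dtype=np.float64) @ FLIP_X  # X component negated
    return m


def transform_vertices(verts, m):
    """Apply a row-vector 4x4 matrix to an (N, 3) array of vertices."""
    verts_h = np.hstack([verts, np.ones((len(verts), 1))])
    return (verts_h @ m)[:, :3]


m = unity_object_to_row_matrix([3, 4, 7], [-30, 10, 70])
print(np.round(m, 4))
verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
print(transform_vertices(verts, m))
```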
I hope this helps, and that these functions can be officially supported. Thanks.