kram icon indicating copy to clipboard operation
kram copied to clipboard

MSL shader output comparison of hlslparser vs. spirv-cross

Open alecazam opened this issue 1 year ago • 1 comments

This is a comparison of the output from hlslparser rewriting HLSL2021 source into MSL vs. spirv-cross doing the same from SPIR-V. The two processing paths below are the steps found in hlslparser/buildShaders.sh. When stepping through a GPU capture of the MSL, the hlslparser MSL is far easier to reason about. Note that hlslparser wraps a class around all function calls in MSL to emulate the global inputs/outputs that HLSL is notorious for.

  • HLSL2021 -> hlslparser -> MSL 2.3
  • HLSL2021 -> hlslparser -> DXC -> Spirv-1.2 -> spirv-cross -> MSL 2.3

hlslparser

#include "ShaderMSL.h"
// Loop bound for the per-light accumulation loop in SamplePSNS::SamplePS.
const static constant constexpr int NUM_LIGHTS = 3;
// Wrapper struct for the pixel shader: the textures, samplers and constant
// buffer are held as reference members so the helper functions below can use
// them directly instead of receiving them as arguments. The fragment entry
// point constructs one instance and calls its SamplePS member.
struct SamplePSNS {
    // Per-light parameters; read in SamplePS and passed on to
    // CalcLightingColor and CalcUnshadowedAmountPCF2x2.
    struct LightState {
        float3 position;
        float3 direction;
        float4 color;
        float4 falloff;
        float4x4 viewProj;
    };
    // Scene-wide constants. SamplePS reads ambientColor, sampleShadowMap and
    // lights; model and viewProj are not referenced by this pixel shader.
    struct SceneConstantBuffer {
        float4x4 model;
        float4x4 viewProj;
        float4 ambientColor;
        bool sampleShadowMap;
        LightState lights[3];
    };
    // Interpolated stage input consumed by SamplePS.
    struct InputPS {
        float4 position [[position]];
        float4 worldpos;
        float2 uv;
        float3 normal;
        float3 tangent;
    };
    // Resource references, bound once by the constructor at the bottom.
    thread depth2d<float>& shadowMap;
    thread texture2d<half>& diffuseMap;
    thread texture2d<half>& normalMap;
    thread sampler& sampleWrap;
    thread sampler& shadowMapSampler;
        constant SceneConstantBuffer & scene
;
    // Returns the per-pixel normal: builds a tangent frame from the vertex
    // normal and tangent, samples normalMap at vTexcoord, remaps the sample
    // with 2*x - 1, and transforms it by the frame via mul().
    // SampleH and mul are not defined in this listing (presumably provided by
    // ShaderMSL.h -- confirm).
    half3 CalcPerPixelNormal(float2 vTexcoord, half3 vVertNormal, half3 vVertTangent) {
        // Compute tangent frame.
        vVertNormal = normalize(vVertNormal);
        vVertTangent = normalize(vVertTangent);
        half3 vVertBinormal = normalize(cross(vVertTangent, vVertNormal));
        half3x3 mTangentSpaceToWorldSpace = half3x3(vVertTangent, vVertBinormal, vVertNormal);
        // Compute per-pixel normal.
        half3 vBumpNormal = (SampleH(normalMap, sampleWrap, vTexcoord)).xyz;
        vBumpNormal = ((2.h * vBumpNormal) - 1.h);
        return mul(vBumpNormal, mTangentSpaceToWorldSpace);
    };
    // Returns vLightColor scaled by three factors: the saturated N.L term,
    // a distance falloff (vFalloffs.x, vFalloffs.y) and an angle falloff
    // (vFalloffs.z, vFalloffs.w). Distance math is done in float and then
    // narrowed to half; the angle math is done in half.
    half4 CalcLightingColor(float3 vLightPos, float3 vLightDir, half4 vLightColor, float4 vFalloffs, float3 vPosWorld, half3 vPerPixelNormal) {
        float3 vLightToPixelUnNormalized = vPosWorld - vLightPos;
        // Dist falloff = 0 at vFalloffs.x, 1 at vFalloffs.x - vFalloffs.y
        float fDist = length(vLightToPixelUnNormalized);
        half fDistFalloff = (half)(saturate((vFalloffs.x - fDist) / vFalloffs.y));
        // Normalize from here on.
        half3 vLightToPixelNormalized = (half3)(normalize(vLightToPixelUnNormalized));
        // Angle falloff = 0 at vFalloffs.z, 1 at vFalloffs.z - vFalloffs.w
        half3 lightDir = (half3)(normalize(vLightDir));
        half fCosAngle = dot(vLightToPixelNormalized, lightDir);
        half fAngleFalloff = saturate((fCosAngle - (half)(vFalloffs.z)) / (half)(vFalloffs.w));
        // Diffuse contribution.
        half fNDotL = saturate((-dot(vLightToPixelNormalized, vPerPixelNormal)));
        return vLightColor * ((fNDotL * fDistFalloff) * fAngleFalloff);
    };
    // Returns the unshadowed amount for a world-space position: transforms
    // vPosWorld by viewProj, returns 1 when z > w, otherwise divides xyz by w
    // and returns the result of a single SampleCmp on shadowMap.
    // lightIndex is accepted but not used in the body.
    // SampleCmp is not defined in this listing (presumably from ShaderMSL.h).
    half CalcUnshadowedAmountPCF2x2(int lightIndex, float4 vPosWorld, float4x4 viewProj) {
        // Compute pixel position in light space.
        float4 vLightSpacePos = vPosWorld;
        vLightSpacePos = mul(vLightSpacePos, viewProj);
        // need to reject before division (assuming revZ, infZ)
        if (vLightSpacePos.z > vLightSpacePos.w) {
            return (half)(1.);
        }
        // near/w for persp, z/1 for ortho
        vLightSpacePos.xyz /= vLightSpacePos.w;
        // Use HW filtering
        return (half)(SampleCmp(shadowMap, shadowMapSampler, vLightSpacePos));
    };
    // Pixel shader body: samples the diffuse map, computes the per-pixel
    // normal, starts from the ambient color and adds each light's
    // contribution, then returns diffuse * saturate(total) widened to float4.
    float4 SamplePS(InputPS input) {
        half4 diffuseColor = SampleH(diffuseMap, sampleWrap, input.uv);
        half3 pixelNormal = CalcPerPixelNormal(input.uv, (half3)(input.normal), (half3)(input.tangent));
        half4 totalLight = (half4)(scene.ambientColor);
        for (int i = 0; i < NUM_LIGHTS; (i++)) {
            // Copies the whole LightState out of the constant buffer.
            LightState light = scene.lights[i];
            half4 lightPass = CalcLightingColor(light.position, light.direction, (half4)(light.color), light.falloff, input.worldpos.xyz, pixelNormal);
            // only single light shadow map
            if (scene.sampleShadowMap && (i == 0)) {
                lightPass *= CalcUnshadowedAmountPCF2x2(i, input.worldpos, light.viewProj);
            }
            totalLight += lightPass;
        }
        return (float4)((diffuseColor * saturate(totalLight)));
    };

    // Binds every reference member to the constructor argument of the same
    // name.
    SamplePSNS(
    thread depth2d<float> & shadowMap, 
    thread texture2d<half> & diffuseMap, 
    thread texture2d<half> & normalMap, 
    thread sampler & sampleWrap, 
    thread sampler & shadowMapSampler, 
    constant SceneConstantBuffer & scene)
     : shadowMap(shadowMap), 
    diffuseMap(diffuseMap), 
    normalMap(normalMap), 
    sampleWrap(sampleWrap), 
    shadowMapSampler(shadowMapSampler), 
    scene(scene) {}
};

// Fragment entry point: receives the stage input plus the bound textures,
// samplers and constant buffer, wraps the resources in a SamplePSNS instance,
// and returns the result of that instance's SamplePS member.
// Note the local variable is also named SamplePS, the same as this function.
fragment float4 SamplePS(
SamplePSNS::InputPS input [[stage_in]], 
depth2d<float> shadowMap [[texture(0)]], 
texture2d<half> diffuseMap [[texture(1)]], 
texture2d<half> normalMap [[texture(2)]], 
sampler sampleWrap [[sampler(0)]], 
sampler shadowMapSampler [[sampler(1)]], 
constant SamplePSNS::SceneConstantBuffer & scene [[buffer(0)]]) {
    SamplePSNS SamplePS(shadowMap, 
    diffuseMap, 
    normalMap, 
    sampleWrap, 
    shadowMapSampler, 
    scene);
    return SamplePS.SamplePS(input);
}

spirv-cross

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Per-light parameters; same fields, in the same order, as
// SamplePSNS::LightState in the hlslparser listing.
struct LightState
{
    float3 position;
    float3 direction;
    float4 color;
    float4 falloff;
    float4x4 viewProj;
};

// Scene constant buffer bound at buffer(0) by the fragment function below.
// sampleShadowMap is a uint here (tested with != 0u in the shader), whereas
// the hlslparser listing declares it as bool.
struct type_ConstantBuffer_SceneConstantBuffer
{
    float4x4 model;
    float4x4 viewProj;
    float4 ambientColor;
    uint sampleShadowMap;
    LightState lights[3];
};

// Fragment output: a single float4 written to color attachment 0.
struct SamplePS_out
{
    float4 out_var_SV_Target0 [[color(0)]];
};

// Fragment stage input, one member per user(locnN) slot. Unlike InputPS in
// the hlslparser listing, there is no [[position]] member here.
// As used by the shader below: TEXCOORD0 is the world position, TEXCOORD1 the
// texture coordinate.
struct SamplePS_in
{
    float4 in_var_TEXCOORD0 [[user(locn0)]];
    float2 in_var_TEXCOORD1 [[user(locn1)]];
    float3 in_var_NORMAL [[user(locn2)]];
    float3 in_var_TANGENT [[user(locn3)]];
};

// Fragment entry point with every helper inlined into one function body.
// The diffuse and normal maps are declared texture2d<float> here, whereas the
// hlslparser listing declares them texture2d<half>.
fragment SamplePS_out SamplePS(SamplePS_in in [[stage_in]], constant type_ConstantBuffer_SceneConstantBuffer& scene [[buffer(0)]], depth2d<float> shadowMap [[texture(0)]], texture2d<float> diffuseMap [[texture(1)]], texture2d<float> normalMap [[texture(2)]], sampler sampleWrap [[sampler(0)]], sampler shadowMapSampler [[sampler(1)]])
{
    SamplePS_out out = {};
    // _73: diffuse texture sample.
    float4 _73 = diffuseMap.sample(sampleWrap, in.in_var_TEXCOORD1, int2(0));
    // _77 / _78: normalized vertex normal and tangent, narrowed to half.
    half3 _77 = normalize(half3(in.in_var_NORMAL));
    half3 _78 = normalize(half3(in.in_var_TANGENT));
    // _84: normal map sample.
    float4 _84 = normalMap.sample(sampleWrap, in.in_var_TEXCOORD1, int2(0));
    // _89: per-pixel normal = tangent frame * (2 * sample - 1).
    half3 _89 = half3x3(_78, normalize(cross(_78, _77)), _77) * ((half4(_84).xyz * half(2.0)) - half3(half(1.0)));
    // _94: running light total, seeded with the ambient color.
    // _95: next total, copied into _94 by the for-loop increment expression.
    half4 _94;
    _94 = half4(scene.ambientColor);
    half4 _95;
    for (int _97 = 0; _97 < 3; _94 = _95, _97++)
    {
        // _111: light-to-pixel vector; _120: its normalized half3 form.
        float3 _111 = in.in_var_TEXCOORD0.xyz - scene.lights[_97].position;
        half3 _120 = half3(fast::normalize(_111));
        // _136: light color * N.L term * distance term * angle term.
        // NOTE(review): the distance term here is falloff.x - (dist / falloff.y)
        // and the angle term is cos - (falloff.z / falloff.w), whereas the
        // hlslparser listing computes (falloff.x - dist) / falloff.y and
        // (cos - falloff.z) / falloff.w. The grouping differs between the two
        // listings -- confirm which one matches the HLSL source.
        half4 _136 = ((half4(scene.lights[_97].color) * clamp(-dot(_120, _89), half(0.0), half(1.0))) * half(fast::clamp(scene.lights[_97].falloff.x - (length(_111) / scene.lights[_97].falloff.y), 0.0, 1.0))) * clamp(dot(_120, half3(fast::normalize(scene.lights[_97].direction))) - (half(scene.lights[_97].falloff.z) / half(scene.lights[_97].falloff.w)), half(0.0), half(1.0));
        // _143: true only when shadow sampling is enabled and this is light 0.
        bool _143;
        if (scene.sampleShadowMap != 0u)
        {
            _143 = _97 == 0;
        }
        else
        {
            _143 = false;
        }
        // _166: this light's contribution after the optional shadow factor.
        half4 _166;
        if (_143)
        {
            // _164: shadow factor. The do/while(false) runs once; break is
            // used to leave it early when z > w.
            half _164;
            do
            {
                // _148: world position multiplied by the light's viewProj.
                float4 _148 = in.in_var_TEXCOORD0 * scene.lights[_97].viewProj;
                float _150 = _148.w;
                if (_148.z > _150)
                {
                    _164 = half(1.0);
                    break;
                }
                // Divide by w, then compare-sample the shadow map at xy
                // against z.
                float3 _156 = _148.xyz / float3(_150);
                _164 = half(shadowMap.sample_compare(shadowMapSampler, _156.xy, _156.z, int2(0)));
                break;
            } while(false);
            _166 = _136 * _164;
        }
        else
        {
            _166 = _136;
        }
        _95 = _94 + _166;
    }
    // Final color: diffuse sample * total light clamped to [0, 1].
    out.out_var_SV_Target0 = float4(half4(_73) * clamp(_94, half4(half(0.0)), half4(half(1.0))));
    return out;
}

alecazam avatar Mar 18 '23 01:03 alecazam