// neural-super-sampling / scenario / 0_pre_process.comp
// Last commit f724cf3 ("Initial content") by temnick
//
// -----------------------------------------------------------------------------
// The proprietary software and information contained in this file is
// confidential and may only be used by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
//
// Copyright (C) 2025. Arm Limited or its affiliates. All rights reserved.
//
// This entire notice must be reproduced on all copies of this file and
// copies of this file may only be made by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
// -----------------------------------------------------------------------------
//
#version 460
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float32 : require
#extension GL_GOOGLE_include_directive : enable
#extension GL_ARM_tensors : require
// includes
#include "typedefs.h"
#include "common.h"
// types
// One 12-channel int8 element of the pre-process tensor, held as three int8 quads.
// The field names list their channels in order. WriteToTensor() quantizes into
// these fields and then flattens them, in declaration order, into a 12-entry array.
struct TensorElement
{
int8_t4 wh_rgb_col_r; // warped_history.rgb, jittered_colour.r
int8_t4 col_gb_dm_fback_r; // jittered_colour.gb, disocclusion mask, feedback.r
int8_t4 fback_gba_ld; // feedback.gba, luma derivative
};
// inputs
// Set 0: read-only inputs. "540p" = render (input) resolution, "1080p" = output resolution.
// "Tm1" suffix = data produced during the previous frame (t-1).
layout (set=0, binding=0) uniform mediump sampler2D _ColourTex; // 540p | R11G11B10 32bpp
layout (set=0, binding=1) uniform highp sampler2D _DepthTex; // 540p | R32_FLOAT 32bpp
layout (set=0, binding=2) uniform mediump sampler2D _MotionVectorTex; // 540p | RG_16 32bpp
layout (set=0, binding=3) uniform mediump sampler2D _HistoryTex; // 1080p | R11G11B10 32bpp
layout (set=0, binding=4) uniform lowp sampler2D _FeedbackTensor; // 1080p | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
layout (set=0, binding=5) uniform highp sampler2D _DepthTm1Tex; // 540p | R32_FLOAT 32bpp
layout (set=0, binding=6) uniform lowp sampler2D _LumaDerivTm1Tex; // 540p | R8G8_UNORM 16bpp
layout (set=0, binding=7) uniform lowp sampler2D _NearestDepthCoordTm1Tex; // 540p | R8_UNORM 8bpp
// outputs
// Set 1: write-only outputs of this pass.
// NOTE(review): set=1 binding=2 is not declared in this shader; presumably reserved
// or used by another pass sharing the layout. Confirm.
layout (set=1, binding=0) uniform writeonly tensorARM<int8_t, 4> _PreprocessTensor; // 540p | 12ch 96bpp
layout (set=1, binding=1, rg8) uniform writeonly lowp image2D _PreProcessLumaDerivOut; // 540p | R8G8 16bpp
layout (set=1, binding=3, r8) uniform writeonly lowp image2D _NearestDepthCoordOut; // 540p | R8 8bpp
// push-constants
// Per-dispatch parameters. Members are addressed by explicit byte offset (std430),
// so the host-side layout depends only on the offsets below, not on member names.
layout(push_constant, std430) uniform PushConstants {
    // ─────────────── 16-byte aligned ───────────────
    layout(offset = 0)   float4   _DeviceToViewDepth;   // 16 B
    layout(offset = 16)  float4   _JitterOffset;        // 16 B (.xy = pixels, .zw = uvs)
    layout(offset = 32)  float4   _JitterOffsetTm1;     // 16 B (.xy = pixels, .zw = uvs)
    layout(offset = 48)  float4   _ScaleFactor;         // 16 B (.xy = scale, .zw = inv scale)
    // ─────────────── 8-byte aligned ───────────────
    layout(offset = 64)  int32_t2 _OutputDims;          // 8 B
    layout(offset = 72)  int32_t2 _InputDims;           // 8 B
    layout(offset = 80)  float2   _InvOutputDims;       // 8 B
    layout(offset = 88)  float2   _InvInputDims;        // 8 B
    layout(offset = 96)  half4    _QuantParams;         // 8 B (.xy SINT, .zw SNORM)
    layout(offset = 104) half4    _MotionDisThreshPad;  // 8 B (.x = motion warp thresh, .y = motion disocclusion thresh,
                                                        //       .z = disocclusion scale, .w = not aliased below)
    // ─────────────── 4-byte aligned ───────────────
    // Deliberately NOT named `_Exposure`. A macro that expands to a member of its own
    // name (`#define _Exposure _Exposure.x`) makes `_InvExposure` (-> `_Exposure.y`)
    // be rescanned into `_Exposure.x.y`, which fails to compile.
    layout(offset = 112) half2    _ExposureParams;      // 4 B (.x = exposure, .y = 1/exp)
    layout(offset = 116) half2    _HistoryPad;          // 4 B (.x = not-history-reset flag, .y = not aliased below)
    // ─────────────── padding to 16-byte struct size ────
    layout(offset = 120) int32_t2 _Padding;             // 8 B
    // Total: **128 bytes**
};
// Convenience mapping for accessing push constants
#define _Scale              _ScaleFactor.xy
#define _InvScale           _ScaleFactor.zw
#define _Exposure           _ExposureParams.x
#define _InvExposure        _ExposureParams.y
#define _JitterOffsetPix    _JitterOffset.xy
#define _JitterOffsetUv     _JitterOffset.zw
#define _JitterOffsetTm1Pix _JitterOffsetTm1.xy
#define _JitterOffsetTm1Uv  _JitterOffsetTm1.zw
#define _MotionWarpThresh   _MotionDisThreshPad.x
#define _MotionDisThresh    _MotionDisThreshPad.y
#define _DisocclusionScale  _MotionDisThreshPad.z
#define _NotHistoryReset    _HistoryPad.x
// Quantization Parameters
// Values are listed in `./parameters.json`.
// They are embedded in the TOSA file and learnt during quantization-aware training (QAT).
// Either macro may be predefined (e.g. by the build) to override the push-constant value.
#ifndef _InputQuantParams
// Network inputs - x["SINT"]
#define _InputQuantParams _QuantParams.xy
#endif
#ifndef _FeedbackQuantParams
// Network outputs (temporal feedback) - activation_post_process_70["SNORM"]
#define _FeedbackQuantParams _QuantParams.zw
#endif
// constants
// Farthest possible device depth. With INVERTED_DEPTH, larger device depths are
// nearer (see the comparisons in FindNearestDepth), so the far limit is 0; otherwise it is 1.
#if defined(INVERTED_DEPTH)
#define MAX_DEPTH 0.0f
#else
#define MAX_DEPTH 1.0f
#endif
// methods
// True when 0 <= pos < size on both axes.
// Converting to unsigned folds the `pos >= 0` test into the upper-bound test:
// negative coordinates become very large unsigned values and fail `< size`.
bool IsOnScreen(int32_t2 pos, int32_t2 size)
{
    const uint32_t2 upos  = uint32_t2(pos);
    const uint32_t2 usize = uint32_t2(size);
    return all(lessThan(upos, usize));
}
// Unfiltered fetch of the motion vector stored at `pixel` (first two channels of
// _MotionVectorTex), narrowed to half precision.
half2 LoadMotion(int32_t2 pixel)
{
    const half2 motion = half2(texelFetch(_MotionVectorTex, pixel, 0).xy);
    return motion;
}
// Current-frame (jittered) colour at `pixel`: exposure-scaled, passed through
// SafeColour(), then tonemapped.
half3 LoadColour(int32_t2 pixel)
{
    const half3 raw_rgb = half3(texelFetch(_ColourTex, pixel, 0).rgb);
    const half3 exposed_rgb = raw_rgb * _Exposure;
    return Tonemap(SafeColour(exposed_rgb));
}
// Reads the previous frame's encoded nearest-depth offset at `pixel` and decodes it.
// The fetch coordinate is clamped into the render area. For pixels outside that
// area the returned offset is forced to (0, 0).
int32_t2 LoadDepthNearestDepthOffsetTm1(int32_t2 pixel)
{
    // 1 when `pixel` is inside the render area, 0 otherwise. (The previous name,
    // `is_oob`, suggested the opposite meaning.)
    int32_t2 in_bounds = int32_t2(IsOnScreen(pixel, _InputDims));
    pixel = clamp(pixel, int32_t2(0), _InputDims - int32_t2(1));
    // 1. read the R8_UNORM-encoded value in [0, 1]
    half encNorm = half(texelFetch(_NearestDepthCoordTm1Tex, pixel, 0).r);
    // 2. back to the integer code in [0, 255], rounding to nearest
    int32_t code = int32_t(encNorm * 255.0 + 0.5);
    // 3. map back to an offset in {-1,0,1}^2, zeroed when out of bounds
    return DecodeNearestDepthCoord(code) * in_bounds;
}
// Gathers the previous frame's depth 2x2 quad around fUV, shifted by the
// nearest-depth offset recorded for that texel in the previous frame.
// The textureGather result is reordered with .wzxy, so depthQuad holds the
// (x, y) offsets (0,0), (1,0), (0,1), (1,1) in that order.
void GatherReconstructedPreviousDepthRQuad(float2 fUV, inout float4 depthQuad)
{
    const int32_t2 iTexel = int32_t2(fUV * _InputDims);
    const float2 fShiftUv = float2(LoadDepthNearestDepthOffsetTm1(iTexel)) * _InvInputDims;
    const float4 fGathered = textureGather(_DepthTm1Tex, fUV + fShiftUv, 0);
    depthQuad = fGathered.wzxy;
}
// Samples the history colour at `uv` (mip 0, filtering per the bound sampler),
// then applies exposure, SafeColour() and Tonemap(), as LoadColour does.
half3 WarpHistory(float2 uv)
{
    const half3 history_rgb = half3(textureLod(_HistoryTex, uv, 0).rgb);
    const half3 exposed_rgb = history_rgb * _Exposure;
    return Tonemap(SafeColour(exposed_rgb));
}
// Samples the network's temporal feedback (SNORM texture aliasing the output
// tensor) at `uv` and converts it back to real values with the feedback quant params.
half4 WarpFeedback(float2 uv)
{
    const half4 snorm_feedback = half4(textureLod(_FeedbackTensor, uv, 0));
    return Dequantize(snorm_feedback, _FeedbackQuantParams);
}
// Samples last frame's luma/derivative history at `uv` (mip 0).
// Layout, as stored by WriteLumaDerivative: .x = derivative, .y = luma.
half2 WarpLumaDerivative(float2 uv)
{
    const half2 previous = half2(textureLod(_LumaDerivTm1Tex, uv, 0).xy);
    return previous;
}
// Temporally accumulated, normalised |luma(t) - luma(t-1)|, used to flag
// high-frequency (jitter-induced) flicker.
//   reproj_uv         : UV of this pixel in frame t-1
//   jittered_colour   : current-frame colour
//   disocclusion_mask : derivative is zeroed where this exceeds kDisocclusionCutoff
// Returns .x = accumulated derivative, .y = current luma (the next frame's history).
half2 CalculateLumaDerivative(float2 reproj_uv, half3 jittered_colour, half disocclusion_mask)
{
    const half kDisocclusionCutoff = 0.01HF;
    const half kMinDelta           = 0.05HF; // smaller deltas are discarded (reduces ghosting)
    const half kMaxDelta           = 0.3HF;  // ~typical maximum; clipping here keeps precision for common values
    const half kCurvePower         = 1.5HF;
    const half kBlend              = 0.1HF;
    const half kInvMaxDelta        = rcp(kMaxDelta);
    const half kInvMaxDeltaPow     = rcp(pow(kMaxDelta, kCurvePower));

    // --- History from frame t-1 ---------------------------------------
    const half2 previous      = WarpLumaDerivative(reproj_uv);
    const half prev_derivative = previous.x;
    const half prev_luma       = previous.y;

    // --- Raw derivative for frame t -----------------------------------
    const half luma  = Luminance(jittered_colour);
    const half delta = abs(luma - prev_luma);

    // --- Clip to [kMinDelta, kMaxDelta] (below kMinDelta -> 0), then shape ---
    const half clipped = min(delta, kMaxDelta) * step(kMinDelta, delta);
    // x^1.5 computed as x * sqrt(x); only valid while kCurvePower == 1.5
    const half shaped = clipped * sqrt(clipped) * kInvMaxDeltaPow;

    // --- Temporal accumulation ----------------------------------------
    // The blend factor falls from kBlend to kBlend*0.1 as the stored derivative
    // grows, so a derivative that has converged high is harder to reset.
    const half history_level = clamp(prev_derivative, 0.HF, kMaxDelta) * kInvMaxDelta;
    const half blend = mix(kBlend, kBlend * 0.1HF, history_level);
    const half accumulated = mix(prev_derivative, shaped, blend);

    // --- Drop disoccluded pixels (mask above the cutoff) ---------------
    const half result = accumulated * step(disocclusion_mask, kDisocclusionCutoff);

    return half2(result, luma);
}
// Scans the 3x3 neighbourhood of iPxPos in _DepthTex and returns the nearest
// device depth there, plus the texel offset of the tap it came from.
// "Nearest" means the smallest value, or the largest when INVERTED_DEPTH is defined.
//   iPxPos              : centre texel
//   iPxSize             : bounds used to reject off-screen neighbours
//   fNearestDepth       : [out] nearest depth found
//   fNearestDepthOffset : [out] offset of that tap (components in {-1, 0, 1})
// The centre tap is taken without an IsOnScreen check. main() only calls this for
// in-bounds pixels.
void FindNearestDepth(int32_t2 iPxPos, int32_t2 iPxSize, out float fNearestDepth, out int32_t2 fNearestDepthOffset)
{
/*
Closely based on:
https://github.com/arm/accuracy-super-resolution-generic-library/blob/38697a58a6e7818ec9d28774bc073f537abb9178/
include/gpu/fsr2/ffxm_fsr2_reconstruct_dilated_velocity_and_previous_depth.h#L59
*/
int32_t iSampleIndex = 0;
const int32_t iSampleCount = 9;
// x, y
// The offsets are written as (x, y) and then swizzled .yx. This offset set is
// symmetric under transposition, so the swizzle only changes the scan order. That
// order matters for ties: only a strictly nearer tap replaces the current choice.
// NOTE(review): presumably chosen to match a (row, col) reference; confirm.
const int32_t2 iSampleOffsets[iSampleCount] = {
int32_t2(+0, +0).yx,
int32_t2(+1, +0).yx,
int32_t2(+0, +1).yx,
int32_t2(+0, -1).yx,
int32_t2(-1, +0).yx,
int32_t2(-1, +1).yx,
int32_t2(+1, +1).yx,
int32_t2(-1, -1).yx,
int32_t2(+1, -1).yx,
};
// pull out the depth loads to allow SC to batch them
// (SC presumably = shader compiler.) depth[i] is fetched with the same offset as
// iSampleOffsets[i]. texelFetchOffset requires a compile-time offset, which is why the
// literals are repeated. Off-texture taps are still fetched here, but their
// values are never used: the IsOnScreen() test below skips them.
float depth[9];
depth[0] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(+0, +0).yx).r);
depth[1] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(+1, +0).yx).r);
depth[2] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(+0, +1).yx).r);
depth[3] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(+0, -1).yx).r);
depth[4] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(-1, +0).yx).r);
depth[5] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(-1, +1).yx).r);
depth[6] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(+1, +1).yx).r);
depth[7] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(-1, -1).yx).r);
depth[8] = float(texelFetchOffset(_DepthTex, iPxPos, 0, int32_t2(+1, -1).yx).r);
// find closest depth
// Start from the centre tap and let strictly nearer on-screen taps replace it.
fNearestDepth = depth[0];
fNearestDepthOffset = iSampleOffsets[0];
// NOTE(review): `#pragma unroll` is not a standard GLSL pragma and may be ignored.
#pragma unroll
for (iSampleIndex = 1; iSampleIndex < iSampleCount; ++iSampleIndex) {
int32_t2 iPos = iPxPos + iSampleOffsets[iSampleIndex];
if (IsOnScreen(iPos, iPxSize)) {
float fNdDepth = depth[iSampleIndex];
#ifdef INVERTED_DEPTH
if (fNdDepth > fNearestDepth) {
#else
if (fNdDepth < fNearestDepth) {
#endif
fNearestDepth = fNdDepth;
fNearestDepthOffset = iSampleOffsets[iSampleIndex];
}
}
}
}
// Render (input) resolution in pixels, taken from the push constants.
int32_t2 RenderSize()
{
    return _InputDims;
}
// Maps a pixel position to NDC in [-1, 1] with +y up (the second output is negated).
// Unlike the FSR2 original, position and size are swizzled .yx first. The output is
// therefore transposed: .x comes from the pixel row and .y from the pixel column.
// In this file the only inputs are the viewport centre and (0, 0), which map to
// (~0, ~0) and (-1, 1) either way.
// Closely based on:
// https://github.com/arm/accuracy-super-resolution-generic-library/blob/
// 38697a58a6e7818ec9d28774bc073f537abb9178/include/gpu/fsr2/ffxm_fsr2_common.h#L457
float2 ComputeNdc(float2 fPxPos, int32_t2 iSize)
{
    const float2 fSwappedPos = fPxPos.yx;
    const float2 fSwappedSize = float2(iSize).yx;
    const float2 fUnit = fSwappedPos / fSwappedSize;
    return fUnit * float2(2.0f, -2.0f) + float2(-1.0f, 1.0f);
}
// Converts a device (depth-buffer) value to view-space depth as
// _DeviceToViewDepth.y / (fDeviceDepth - _DeviceToViewDepth.x).
// Closely based on:
// https://github.com/arm/accuracy-super-resolution-generic-library/blob/
// 38697a58a6e7818ec9d28774bc073f537abb9178/include/gpu/fsr2/ffxm_fsr2_common.h#L462
// `fDeviceToViewDepth` / `_DeviceToViewDepth` details found in:
// https://github.com/arm/accuracy-super-resolution-generic-library/blob/
// 0501f490bd9946a2e1806b5363d7ab8a9a6a5e0a/src/components/fsr2/ffxm_fsr2.cpp#L829
float GetViewSpaceDepth(float fDeviceDepth)
{
    const float fBias = _DeviceToViewDepth.x;
    const float fScale = _DeviceToViewDepth.y;
    return fScale / (fDeviceDepth - fBias);
}
// Reconstructs a view-space position from a pixel position and device depth.
// Z comes from GetViewSpaceDepth(). X and Y are NDC * _DeviceToViewDepth.zw * Z.
// Closely based on:
// https://github.com/arm/accuracy-super-resolution-generic-library/blob/
// 38697a58a6e7818ec9d28774bc073f537abb9178/include/gpu/fsr2/ffxm_fsr2_common.h#L475
float3 GetViewSpacePosition(int32_t2 iViewportPos, int32_t2 iViewportSize, float fDeviceDepth)
{
    const float fViewZ = GetViewSpaceDepth(fDeviceDepth);
    const float2 fNdc = ComputeNdc(float2(iViewportPos), iViewportSize);
    const float2 fViewXY = _DeviceToViewDepth.zw * fNdc * fViewZ;
    return float3(fViewXY, fViewZ);
}
// Description of a 2x2 bilinear footprint, filled by GetBilinearSamplingData().
//   iOffsets[i]   : texel offset of tap i from iBasePos
//   fWeights[i]   : bilinear weight of tap i (the four weights sum to 1)
//   iBasePos      : integer texel of the top-left tap of the quad
//   fQuadCenterUv : UV used to gather the quad
struct BilinearSamplingData
{
int32_t2 iOffsets[4];
float fWeights[4];
int32_t2 iBasePos;
float2 fQuadCenterUv;
};
// Computes the 2x2 bilinear footprint of fUv on a texture of size iSize.
// Taps are ordered, as (x, y) offsets from iBasePos: [0]=(0,0) [1]=(1,0) [2]=(0,1) [3]=(1,1).
// That order matches fWeights below and the textureGather(...).wzxy ordering used by
// GatherReconstructedPreviousDepthRQuad. iOffsets[1] and iOffsets[2] were previously
// swapped, so the on-screen test for those two taps checked the wrong texel.
// Closely based on:
// https://github.com/arm/accuracy-super-resolution-generic-library/blob/
// 38697a58a6e7818ec9d28774bc073f537abb9178/include/gpu/fsr2/ffxm_fsr2_common.h#L548
BilinearSamplingData GetBilinearSamplingData(float2 fUv, int32_t2 iSize)
{
    BilinearSamplingData data;

    // Continuous texel coordinate with texel centres at integer positions.
    float2 fPxSample = (fUv * iSize) - float2(0.5f, 0.5f);
    data.iBasePos = int32_t2(floor(fPxSample));
    data.fQuadCenterUv = (fPxSample + 0.5f) / float2(iSize);
    float2 fPxFrac = fract(fPxSample);

    data.iOffsets[0] = int32_t2(0, 0);
    data.iOffsets[1] = int32_t2(1, 0);
    data.iOffsets[2] = int32_t2(0, 1);
    data.iOffsets[3] = int32_t2(1, 1);

    data.fWeights[0] = (1.f - fPxFrac.x) * (1.f - fPxFrac.y);
    data.fWeights[1] = (fPxFrac.x) * (1.f - fPxFrac.y);
    data.fWeights[2] = (1.f - fPxFrac.x) * (fPxFrac.y);
    data.fWeights[3] = (fPxFrac.x) * (fPxFrac.y);

    return data;
}
// Depth-based disocclusion estimate at fUvSample (in the previous frame's render space).
// It compares fCurrentDepthSample with the previous frame's reconstructed depth quad.
// The returned value is in [0, 1]; higher means more likely disoccluded.
// Returns 0 when no tap contributes any weight.
// Deviation from FSR2: off-screen taps still add their bilinear weight to
// fWeightSum, with no depth contribution. That pushes the result toward 1,
// presumably so that reprojecting off-screen counts as disocclusion.
// Closely based on:
// https://github.com/arm/accuracy-super-resolution-generic-library/blob/
// 38697a58a6e7818ec9d28774bc073f537abb9178/include/gpu/fsr2/ffxm_fsr2_depth_clip.h#L36
float ComputeDepthClip(float2 fUvSample, float fCurrentDepthSample)
{
    const float fReconstructedDepthBilinearWeightThreshold = 0.1f;
    const float Ksep = 1.37e-05f;

    // Terms that are identical for every tap.
    const int32_t2 iRenderSize = RenderSize();
    // FSR2 name kept; the value is the diagonal length of the render size.
    const float fHalfViewportWidth = length(float2(iRenderSize));
    const float fResolutionFactor = saturate(length(float2(iRenderSize)) / length(float2(1920.0f, 1080.0f)));
    const float fPower = lerp(1.0f, 3.0f, fResolutionFactor);
    const float fCurrentDepthViewSpace = GetViewSpaceDepth(fCurrentDepthSample);

    BilinearSamplingData bilinearInfo = GetBilinearSamplingData(fUvSample, iRenderSize);
    float4 fPrevDepthSamples;
    GatherReconstructedPreviousDepthRQuad(bilinearInfo.fQuadCenterUv, fPrevDepthSamples);

    float fDepth = 0.0f;
    float fWeightSum = 0.0f;
    for (int32_t iTap = 0; iTap < 4; ++iTap)
    {
        const float fWeight = bilinearInfo.fWeights[iTap];
        const int32_t2 iSamplePos = bilinearInfo.iBasePos + bilinearInfo.iOffsets[iTap];

        if (!IsOnScreen(iSamplePos, iRenderSize))
        {
            fWeightSum += fWeight;
            continue;
        }

        if (fWeight > fReconstructedDepthBilinearWeightThreshold)
        {
            const float fPrevDepthSample = fPrevDepthSamples[iTap];
            const float fPrevNearestDepthViewSpace = GetViewSpaceDepth(fPrevDepthSample);
            const float fDepthDiff = fCurrentDepthViewSpace - fPrevNearestDepthViewSpace;

            // Only taps where the current surface lies behind the previous one contribute.
            if (fDepthDiff > 0.0f)
            {
#ifdef INVERTED_DEPTH
                const float fPlaneDepth = min(fPrevDepthSample, fCurrentDepthSample);
#else
                const float fPlaneDepth = max(fPrevDepthSample, fCurrentDepthSample);
#endif
                const float3 fCenter = GetViewSpacePosition(int32_t2(iRenderSize * 0.5f), iRenderSize, fPlaneDepth);
                const float3 fCorner = GetViewSpacePosition(int32_t2(0, 0), iRenderSize, fPlaneDepth);
                const float fDepthThreshold = max(fCurrentDepthViewSpace, fPrevNearestDepthViewSpace);
                const float Kfov = length(fCorner) / length(fCenter);
                const float fRequiredDepthSeparation = Ksep * Kfov * fHalfViewportWidth * fDepthThreshold;

                fDepth += pow(saturate(float(fRequiredDepthSeparation / fDepthDiff)), fPower) * fWeight;
                fWeightSum += fWeight;
            }
        }
    }

    return (fWeightSum > 0) ? saturate(1.0f - fDepth / fWeightSum) : 0.0f;
}
// Stores (.x = derivative, .y = luma) for the next frame's CalculateLumaDerivative.
// The target image is rg8, so the padding .zw components are not stored.
void WriteLumaDerivative(int32_t2 pixel, half2 derivative)
{
    const half4 texel = half4(derivative.x, derivative.y, 0.HF, 1.HF);
    imageStore(_PreProcessLumaDerivOut, pixel, texel);
}
// Stores the 8-bit nearest-depth code as an R8_UNORM value (code / 255).
void WriteNearestDepthOffset(int32_t2 pixel, uint8_t offset)
{
    const half unorm_code = half(offset) / 255.HF;
    imageStore(_NearestDepthCoordOut, pixel, half4(unorm_code, half3(0.HF, 0.HF, 1.HF)));
}
// Quantizes the 12 network-input channels with _InputQuantParams and writes them
// to _PreprocessTensor at (0, y, x, 0).
// Channel order:
//   0-2 warped history rgb | 3-5 jittered colour rgb | 6 disocclusion mask |
//   7-10 temporal feedback rgba | 11 luma derivative
// NOTE(review): the coordinate order suggests an NHWC tensor with batch 0; confirm.
void WriteToTensor(int32_t2 outputPixel, half3 input_colour, half3 history, half disocclusion_mask, half luma_derivative, half4 temporal_feedback)
{
    const int8_t4 q_hist_col = Quantize(half4(history, input_colour.r), _InputQuantParams);
    const int8_t4 q_col_dm_fb = Quantize(half4(input_colour.gb, disocclusion_mask, temporal_feedback.r), _InputQuantParams);
    const int8_t4 q_fb_ld = Quantize(half4(temporal_feedback.gba, luma_derivative), _InputQuantParams);

    int8_t channels[12] = int8_t[12](
        q_hist_col.x,  q_hist_col.y,  q_hist_col.z,  q_hist_col.w,
        q_col_dm_fb.x, q_col_dm_fb.y, q_col_dm_fb.z, q_col_dm_fb.w,
        q_fb_ld.x,     q_fb_ld.y,     q_fb_ld.z,     q_fb_ld.w
    );

    const uint coord[4] = uint[4](0u, uint(outputPixel.y), uint(outputPixel.x), 0u);
    tensorWriteARM(_PreprocessTensor, coord, channels);
}
// entry-point
// One invocation per render-resolution (input) pixel, in 16x16 workgroups.
// For that pixel it writes:
//   - the 12-channel int8 element of _PreprocessTensor,
//   - the encoded nearest-depth offset (_NearestDepthCoordOut),
//   - the luma/derivative history for the next frame (_PreProcessLumaDerivOut).
layout(local_size_x = 16, local_size_y = 16) in;
void main()
{
int32_t2 input_pixel = int32_t2(gl_GlobalInvocationID.xy);
// Invocations past the render size (partial workgroups) exit early.
if (any(greaterThanEqual(input_pixel, _InputDims))) return;
// UV of the pixel centre.
float2 uv = (float2(input_pixel) + 0.5f) * _InvInputDims;
//-------------------------------------------------------------------------
// 1) Dilate depth, find nearest pixel coordinate
//-------------------------------------------------------------------------
float depth_dilated = float(0.f);
int32_t2 nearest_pixel_offset = int32_t2(0);
FindNearestDepth(input_pixel, RenderSize(), depth_dilated, nearest_pixel_offset);
//-------------------------------------------------------------------------
// 2) Load motion vectors
//-------------------------------------------------------------------------
// Motion is read at the dilated (nearest-depth) tap. That tap is always on screen,
// because FindNearestDepth only accepts on-screen neighbours.
half2 motion = LoadMotion(input_pixel + nearest_pixel_offset);
// Suppress very small motion - no value in resampling here
// Motion is scaled by the render size to get pixels and is subtracted from uv
// below, i.e. it is treated as a UV-space delta from t-1 to t.
half2 motion_pix = motion * half2(RenderSize());
// NOTE(review): this dot() is in half precision. Motions beyond ~181 px overflow
// to +inf, which still compares greater than the thresholds.
motion *= half(dot(motion_pix, motion_pix) > _MotionWarpThresh);
// Calculate sample position(s) for everything in `tm1` frame
float2 reproj_uv = uv - float2(motion);
float2 unjitter_tm1_uv = reproj_uv - _JitterOffsetTm1Uv;
//-------------------------------------------------------------------------
// 3) Calculate depth-based disocclusion mask
//-------------------------------------------------------------------------
half disocclusion_mask = half(ComputeDepthClip(unjitter_tm1_uv, depth_dilated));
// Scale disocclusion mask on static frames to let network know this is happening under
// static conditions, reduces jitter differences across frames causing false flags
// `motion_pix` was computed before the suppression in step 2, so this uses the raw motion.
half dm_scale = dot(motion_pix, motion_pix) > _MotionDisThresh ? half(1.0f) : _DisocclusionScale;
disocclusion_mask = disocclusion_mask * dm_scale;
//-------------------------------------------------------------------------
// 4) Downsample + warp history buffer
//-------------------------------------------------------------------------
half3 warped_history = WarpHistory(reproj_uv);
//-------------------------------------------------------------------------
// 5) Read current low-res / jittered / aliased colour
//-------------------------------------------------------------------------
half3 jittered_colour = LoadColour(input_pixel);
//-------------------------------------------------------------------------
// 6) Calculate derivative of `luma`
// helps identifying high-frequency flicker due to jitter
//-------------------------------------------------------------------------
half2 luma_derivative = CalculateLumaDerivative(reproj_uv, jittered_colour, disocclusion_mask);
//-------------------------------------------------------------------------
// 7) Warp temporal feedback
//-------------------------------------------------------------------------
half4 temporal_feedback = WarpFeedback(reproj_uv);
//-------------------------------------------------------------------------
// 8) Convert dilated depth coord to a position offset
//-------------------------------------------------------------------------
uint8_t enc_depth_offset = EncodeNearestDepthCoord(nearest_pixel_offset);
//-------------------------------------------------------------------------
// 9) Write Outputs
//-------------------------------------------------------------------------
// Consumed by NE
WriteToTensor(
input_pixel,
jittered_colour, // 3ch
warped_history, // 3ch
disocclusion_mask, // 1ch
luma_derivative.x, // 1ch
temporal_feedback // 4ch
); // total: 12ch
// Consumed by post process and frame t+1
WriteNearestDepthOffset(input_pixel, enc_depth_offset);
// Consumed at frame t+1
WriteLumaDerivative(input_pixel, luma_derivative);
}