// neural-super-sampling / scenario / 2_post_process.comp
// Commit f724cf3 ("Initial content") by temnick.
//
// -----------------------------------------------------------------------------
// The proprietary software and information contained in this file is
// confidential and may only be used by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
//
// Copyright (C) 2025. Arm Limited or its affiliates. All rights reserved.
//
// This entire notice must be reproduced on all copies of this file and
// copies of this file may only be made by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
// -----------------------------------------------------------------------------
//
#version 460
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float32 : require
#extension GL_GOOGLE_include_directive : enable
// defines
// Identifiers for the upscale factors; SCALE_MODE below is compared against these.
#define SCALE_1_0X 0
#define SCALE_1_3X 1
#define SCALE_1_5X 2
#define SCALE_2_0X 3
// settings
// HISTORY_CATMULL: when defined, LoadWarpedHistory resamples the history with the
// 5-tap Catmull-Rom filter (LoadHistoryCatmull) instead of a single LoadHistory fetch.
#define HISTORY_CATMULL
// SCALE_MODE: only SCALE_2_0X has a LoadKPNWeight/LoadAndFilterColour implementation
// in this file; any other value hits the "Unsupported SCALE_MODE" #error.
#define SCALE_MODE SCALE_2_0X
// includes
#include "typedefs.h"
#include "common.h"
#include "kernel_lut.h"
// inputs
// Descriptor set 0: sampled inputs. Descriptor set 1: the storage image written by this pass.
layout (set=0, binding=0) uniform mediump sampler2D _ColourTex; // 540p | R11G11B10 32bpp
layout (set=0, binding=1) uniform mediump sampler2D _MotionVectorTex; // 540p | RG16_FLOAT 32bpp
layout (set=0, binding=2) uniform mediump sampler2D _HistoryTex; // 1080p | R11G11B10 32bpp
// Quantized KPN kernel slices; dequantized with _K0.._K3QuantParams in LoadKPNWeight.
layout (set=0, binding=3) uniform lowp sampler2D _K0Tensor; // 540p | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
layout (set=0, binding=4) uniform lowp sampler2D _K1Tensor; // 540p | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
layout (set=0, binding=5) uniform lowp sampler2D _K2Tensor; // 540p | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
layout (set=0, binding=6) uniform lowp sampler2D _K3Tensor; // 540p | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
// Quantized temporal parameters (theta, alpha); read by LoadTemporalParameters.
layout (set=0, binding=7) uniform lowp sampler2D _TemporalTensor; // 540p | R8G8B8A8_SNORM 32bpp | Tensor->Texture Alias (Linear)
// 8-bit code for the nearest-depth neighbour offset; decoded by LoadNearestDepthOffset.
layout (set=0, binding=8) uniform lowp sampler2D _NearestDepthCoordTex; // 540p | R8_UNORM 8bpp
// outputs
layout (set=1, binding=0, r11f_g11f_b10f) uniform writeonly mediump image2D _UpsampledColourOut; // 1080p | R11G11B10 32bpp
// push-constants
// Per-dispatch constants. Offsets are explicit, so the host-side struct must match
// this layout byte-for-byte.
layout(push_constant, std430) uniform PushConstants {
// ─────────────── 8-byte aligned ───────────────
layout(offset = 0) int32_t2 _OutputDims; // 8 B
layout(offset = 8) int32_t2 _InputDims; // 8 B
layout(offset = 16) float2 _InvOutputDims; // 8 B
// NOTE(review): _InvInputDims and _Scale are not read anywhere in this file.
layout(offset = 24) float2 _InvInputDims; // 8 B
layout(offset = 32) float2 _Scale; // 8 B
// Maps output-pixel coordinates onto the input grid (see LoadAndFilterColour).
layout(offset = 40) float2 _InvScale; // 8 B
// ─────────────── 4-byte aligned ───────────────
// Repeat period (x = width, y = height) of the kernel LUT tiling.
layout(offset = 48) int16_t2 _IndexModulo; // 4 B
// Default dequantization parameters shared by every tensor (see *QuantParams macros).
layout(offset = 52) half2 _QuantParams; // 4 B
// Added to the output pixel before taking it modulo _IndexModulo.
layout(offset = 56) int16_t2 _LutOffset; // 4 B
// .x = exposure, .y = inverse exposure.
layout(offset = 60) half2 _ExposurePair; // 4 B
// .x = history-not-reset factor (multiplies theta); .y is not read in this file.
layout(offset = 64) half2 _HistoryPad; // 4 B
layout(offset = 68) half2 _MotionThreshPad; // 4 B (.x = motion, .y = unused)
// NOTE(review): not read by this shader, and member alignment does not need it
// (offset 72 is already 8-aligned); presumably mirrors the host struct — confirm.
layout(offset = 72) int32_t _Padding0; // 4 B (explicit pad for alignment)
// Total: **76 bytes**
};
// Convenience mapping for accessing push constants
// _Exposure scales input colour and history into the working space; _InvExposure
// is applied to the final colour in main() to undo it.
#define _Exposure _ExposurePair.x
#define _InvExposure _ExposurePair.y
// Multiplies theta in LoadTemporalParameters: 0 forces theta to 0, so history is
// not blended in (reset).
#define _NotHistoryReset _HistoryPad.x
// Threshold on the squared motion length, in output pixels, below which motion
// is zeroed (see LoadWarpedHistory).
#define _MotionThresh _MotionThreshPad.x
// Quantization Parameters
// inside: `./parameters.json`
// these values are embedded inside the TOSA file and learnt during QAT
// Each macro can be predefined (e.g. via a compiler define) to give a tensor its
// own parameters; otherwise every tensor falls back to the shared _QuantParams
// push constant.
#ifndef _K0QuantParams
// outputs - activation_post_process_45["SNORM"]
#define _K0QuantParams _QuantParams.xy
#endif
#ifndef _K1QuantParams
// outputs - activation_post_process_50["SNORM"]
#define _K1QuantParams _QuantParams.xy
#endif
#ifndef _K2QuantParams
// outputs - activation_post_process_55["SNORM"]
#define _K2QuantParams _QuantParams.xy
#endif
#ifndef _K3QuantParams
// outputs - activation_post_process_60["SNORM"]
#define _K3QuantParams _QuantParams.xy
#endif
#ifndef _TemporalQuantParams
// outputs - activation_post_process_65["SNORM"]
#define _TemporalQuantParams _QuantParams.xy
#endif
// methods
// Unfiltered fetch of the motion vector stored in the RG channels of
// _MotionVectorTex at integer texel `pixel` (mip 0).
half2 LoadMotion(int32_t2 pixel)
{
    float4 texel = texelFetch(_MotionVectorTex, pixel, 0);
    return half2(texel.rg);
}
// Single sampled fetch of the history colour at normalised coordinate `uv`
// (mip 0); the alpha channel is discarded.
half3 LoadHistory(float2 uv)
{
    float4 texel = textureLod(_HistoryTex, uv, 0);
    return half3(texel.rgb);
}
// Resample the history texture at `uv` with a 5-tap approximation of a 4x4
// Catmull-Rom filter. On each axis, the two middle taps (weights w1, w2) are merged
// into one fetch placed at a fractional offset. The four corner taps of the 4x4
// footprint are dropped, and the result is renormalised by the sum of the five
// weights actually used.
// NOTE(review): the merged w1/w2 fetch at offset w2/(w1+w2) only reproduces the
// two-tap sum if the sampler bound to _HistoryTex filters linearly — confirm.
// uv: normalised coordinate; _OutputDims is used as the history texel grid.
// Returns the filtered colour. It is clamped to the per-channel min/max of the five
// taps only when some channel of the filtered result is negative.
half3 LoadHistoryCatmull(float2 uv)
{
//------------------------------------------------------------------------------------
// 1) Compute Catmull–Rom weights
//------------------------------------------------------------------------------------
// baseFloor is the nearest texel centre at or below scaledUV on each axis;
// f in [0, 1) is the fractional distance from it.
float2 scaledUV = uv * _OutputDims;
float2 baseFloor = floor(scaledUV - 0.5) + 0.5;
half2 f = half2(scaledUV - baseFloor);
half2 f2 = f * f;
half2 f3 = f2 * f;
// Catmull–Rom basis
half2 w0 = f2 - 0.5HF * (f3 + f);
half2 w1 = 1.5HF * f3 - 2.5HF * f2 + 1.0HF;
half2 w3 = 0.5HF * (f3 - f2);
half2 w2 = (1.0HF - w0) - w1 - w3; // = 1 - (w0 + w1 + w3)
// Combine w1 and w2 for center axis
// (w12 = 1 + 0.5*f*(1 - f), i.e. in [1, 1.125], so the divisions below are safe)
half2 w12 = w1 + w2;
half wx0 = w0.x, wy0 = w0.y;
half wx1 = w12.x, wy1 = w12.y;
half wx2 = w3.x, wy2 = w3.y;
// Final weights for the cross sample layout
half wUp = wx1 * wy0; // center in X, up in Y
half wDown = wx1 * wy2; // center in X, down in Y
half wLeft = wx0 * wy1; // left in X, center in Y
half wRight = wx2 * wy1; // right in X, center in Y
half wCenter = wx1 * wy1; // center in X, center in Y
// Fractional offsets for the center
// (position of the merged w1/w2 fetch, measured from baseFloor)
half dx = w2.x / wx1;
half dy = w2.y / wy1;
//------------------------------------------------------------------------------------
// 2) Gather the 5 taps
//------------------------------------------------------------------------------------
// Tap positions are in texel units relative to baseFloor (-1 and +2 are the outer
// taps); w = 1 on every tap so that accum.w becomes the sum of the applied weights.
half4 left = half4(LoadHistory((baseFloor + float2(-1.0, dy)) * _InvOutputDims ), 1.HF);
half4 up = half4(LoadHistory((baseFloor + float2(dx, -1.0)) * _InvOutputDims ), 1.HF);
half4 center = half4(LoadHistory((baseFloor + float2(dx, dy)) * _InvOutputDims ), 1.HF);
half4 right = half4(LoadHistory((baseFloor + float2(2.0, dy)) * _InvOutputDims ), 1.HF);
half4 down = half4(LoadHistory((baseFloor + float2(dx, 2.0)) * _InvOutputDims ), 1.HF);
//------------------------------------------------------------------------------------
// 3) Accumulate and track min/max
//------------------------------------------------------------------------------------
half4 accum = up * wUp +
left * wLeft +
center* wCenter +
right * wRight +
down * wDown;
half3 cmin3 = min(up.rgb,
min(left.rgb,
min(center.rgb,
min(right.rgb, down.rgb))));
half3 cmax3 = max(up.rgb,
max(left.rgb,
max(center.rgb,
max(right.rgb, down.rgb))));
//------------------------------------------------------------------------------------
// 4) Final color
//------------------------------------------------------------------------------------
// Renormalise, because the dropped corner weights make the five-tap sum differ from 1.
half3 color = accum.rgb * rcp(accum.w);
// dering in the case where we have negative values, we don't do this all the time
// as it can impose unnecessary blurring on the output
return any(lessThan(color, half3(0.HF)))
? clamp(color, cmin3, cmax3)
: color;
}
// Decode the nearest-depth neighbour offset stored at input texel `pixel`.
// The R8_UNORM texel holds an integer code normalised to [0, 1]. The code is
// scaled back to [0, 255], rounded, and mapped back to a per-axis offset in
// {-1, 0, 1}^2 by DecodeNearestDepthCoord.
int32_t2 LoadNearestDepthOffset(int32_t2 pixel)
{
    half normalised_code = half(texelFetch(_NearestDepthCoordTex, pixel, 0).r);
    int32_t code = int32_t(normalised_code * 255.0 + 0.5);
    return DecodeNearestDepthCoord(code);
}
// Reproject the previous frame's output into the current frame for this pixel.
//   uv          : normalised coordinate of the output pixel centre.
//   input_pixel : input-resolution texel containing uv.
//   onscreen    : set to 1.0 when the reprojected coordinate lies inside [0, 1)^2,
//                 otherwise 0.0.
// Returns the resampled history multiplied by _Exposure and passed through SafeColour.
half3 LoadWarpedHistory(float2 uv, int32_t2 input_pixel, out half onscreen)
{
    // Dilate motion vectors: take the motion at the neighbour chosen by the
    // nearest-depth pass instead of at input_pixel itself. The neighbour can lie
    // outside the image at the border. texelFetch results are undefined for
    // out-of-range coordinates, so clamp to the valid texel range; this is a
    // no-op for in-range neighbours.
    int32_t2 nearest_offset = LoadNearestDepthOffset(input_pixel);
    int32_t2 motion_pixel = clamp(input_pixel + nearest_offset, int32_t2(0), _InputDims - 1);
    half2 motion = LoadMotion(motion_pixel);

    // Suppress very small motion, so static pixels are not resampled. The
    // comparison uses the squared length in output pixels. In half precision a
    // very large motion overflows to +inf, and the test still passes as intended.
    half2 motion_pix = motion * half2(_OutputDims);
    motion *= half(dot(motion_pix, motion_pix) > _MotionThresh);

    // Previous-frame coordinate of this pixel; motion is expressed in UV units.
    float2 reproj_uv = uv - float2(motion);

    // Flag whether the reprojected location is on screen (valid history).
    onscreen = half(
        all(greaterThanEqual(reproj_uv, float2(0.0))) &&
        all(lessThan(reproj_uv, float2(1.0)))
    );

#ifdef HISTORY_CATMULL
    half3 warped_history = LoadHistoryCatmull(reproj_uv);
#else
    half3 warped_history = LoadHistory(reproj_uv);
#endif
    return SafeColour(warped_history * _Exposure);
}
#if SCALE_MODE == SCALE_2_0X
/*
  Optimised special case for applying a 4x4 kernel to a sparse, jitter-aware
  2x2 upsampled image.

  Each kernel tensor (_K0.._K3) supplies four dequantized taps in its RGBA
  channels. lut_idx selects which interleaved combination of those 16 taps
  belongs to the current output pixel:
    0 -> (k0.x, k2.x, k0.z, k2.z)
    1 -> (k1.x, k3.x, k1.z, k3.z)
    2 -> (k0.y, k2.y, k0.w, k2.w)
    any other value -> (k1.y, k3.y, k1.w, k3.w)
*/
half4 LoadKPNWeight(float2 uv, int16_t lut_idx)
{
    half4 slice0 = Dequantize(half4(textureLod(_K0Tensor, uv, 0)), _K0QuantParams);
    half4 slice1 = Dequantize(half4(textureLod(_K1Tensor, uv, 0)), _K1QuantParams);
    half4 slice2 = Dequantize(half4(textureLod(_K2Tensor, uv, 0)), _K2QuantParams);
    half4 slice3 = Dequantize(half4(textureLod(_K3Tensor, uv, 0)), _K3QuantParams);

    if (lut_idx == 0)
    {
        return half4(slice0.x, slice2.x, slice0.z, slice2.z);
    }
    if (lut_idx == 1)
    {
        return half4(slice1.x, slice3.x, slice1.z, slice3.z);
    }
    if (lut_idx == 2)
    {
        return half4(slice0.y, slice2.y, slice0.w, slice2.w);
    }
    return half4(slice1.y, slice3.y, slice1.w, slice3.w);
}
// 2x kernel-prediction (KPN) reconstruction for one output pixel.
//   output_pixel : output-resolution pixel being shaded.
//   uv           : normalised coordinate of its centre (used to read the kernel tensors).
//   col_to_accum : receives the centre tap (rgb, w = 1) when the LUT marks it as this
//                  pixel's own sample (zero offset), otherwise all zeros.
// Returns the four input taps blended with the predicted weights and normalised
// by the total weight.
half3 LoadAndFilterColour(int32_t2 output_pixel, float2 uv, out half4 col_to_accum)
{
    // --- Pick this pixel's entry in the kernel LUT -----------------------------
    // The LUT repeats every _IndexModulo pixels (after shifting by _LutOffset).
    // The tile position is flattened row-major using the tile width _IndexModulo.x.
    // NOTE(review): GLSL leaves % undefined for negative operands, so this relies
    // on output_pixel + _LutOffset being non-negative — confirm on the host side.
    int16_t2 tile_pos = (int16_t2(output_pixel) + _LutOffset) % _IndexModulo;
    int16_t tile_idx = tile_pos.y * _IndexModulo.x + tile_pos.x;
    KernelTile tile = kernelLUT[tile_idx];

    // --- Predicted weights, clamped to [EPS, 1] --------------------------------
    // Presumably the clamp keeps the normaliser (sum of weights) strictly positive.
    half4 weights = clamp(LoadKPNWeight(uv, tile_idx), half4(EPS), half4(1.HF));

    // --- Input-resolution tap coordinates --------------------------------------
    // The LUT offsets are added to the output pixel centre. _InvScale then maps
    // the result onto the input grid, and it is clamped to the image bounds.
    float2 centre = float2(output_pixel) + 0.5f;
    int16_t4 tap_x = clamp(int16_t4(floor((float4(centre.x) + float4(tile.dx)) * _InvScale.x)), int16_t4(0), int16_t4(_InputDims.x - 1));
    int16_t4 tap_y = clamp(int16_t4(floor((float4(centre.y) + float4(tile.dy)) * _InvScale.y)), int16_t4(0), int16_t4(_InputDims.y - 1));

    // --- Gather exposed, sanitised colour; w = 1 carries the weight sum --------
    f16mat4x4 taps;
    for (int i = 0; i < 4; ++i)
    {
        half3 rgb = half3(texelFetch(_ColourTex, int16_t2(tap_x[i], tap_y[i]), 0).rgb);
        taps[i] = half4(SafeColour(rgb * half3(_Exposure)), 1.HF);
    }

    // --- Centre tap goes to accumulation only if it is this pixel's own sample -
    bool own_sample = tile.dx[CENTER_TAP] == 0 && tile.dy[CENTER_TAP] == 0;
    col_to_accum = taps[CENTER_TAP] * half(own_sample);

    // --- Blend: matrix * vector sums the columns (taps) scaled by the weights ---
    half4 blended = taps * weights;
    return half3(blended.rgb * rcp(blended.w));
}
#else
#error "Unsupported SCALE_MODE"
#endif // SCALE_MODE == SCALE_2_0X
// Read the two temporal parameters at `uv` from _TemporalTensor.
//   theta : history weight used to rectify history in main(); forced to 0 when
//           _NotHistoryReset is 0. In [0, 1] assuming the dequantized value is.
//   alpha : accumulation weight for the new sample, remapped as 0.35 * x + 0.05
//           (i.e. [0.05, 0.4] for x in [0, 1]).
void LoadTemporalParameters(float2 uv, out half theta, out half alpha)
{
    half2 quantised = half2(textureLod(_TemporalTensor, uv, 0).xy);
    half2 params = Dequantize(quantised, _TemporalQuantParams);

    theta = params.x * _NotHistoryReset;
    alpha = params.y * 0.35HF + 0.05HF;
}
// Store `colour` at `pixel` of the output image after sanitising it with
// SafeColour; the alpha channel is always written as 1.0.
void WriteUpsampledColour(int32_t2 pixel, half3 colour)
{
    half3 sanitised = SafeColour(colour);
    imageStore(_UpsampledColourOut, pixel, half4(sanitised, 1.0));
}
// entry-point: one invocation per output pixel, in 16x16 workgroups.
layout(local_size_x = 16, local_size_y = 16) in;
void main()
{
    int32_t2 out_px = int32_t2(gl_GlobalInvocationID.xy);

    // Invocations of partially covered workgroups past the image edge do nothing.
    if (any(greaterThanEqual(out_px, _OutputDims)))
    {
        return;
    }

    // Normalised coordinate of this pixel's centre, and the input texel it falls in.
    float2 out_uv = (float2(out_px) + 0.5) * _InvOutputDims;
    int32_t2 in_px = int32_t2(out_uv * _InputDims);

    //-------------------------------------------------------------------------
    // 1) Warp history into the current frame
    //-------------------------------------------------------------------------
    half history_onscreen;
    half3 warped_history = LoadWarpedHistory(out_uv, in_px, history_onscreen);

    //-------------------------------------------------------------------------
    // 2) KPN filter -> filtered colour (+ this pixel's own sample, if any)
    //-------------------------------------------------------------------------
    half4 own_sample;
    half3 filtered = LoadAndFilterColour(out_px, out_uv, own_sample);

    //-------------------------------------------------------------------------
    // 3) Load temporal parameters
    //-------------------------------------------------------------------------
    half theta, alpha;
    LoadTemporalParameters(out_uv, theta, alpha);

    //-------------------------------------------------------------------------
    // 4) Rectify history; offscreen history (onscreen == 0) falls back to the
    //    filtered colour, which forces a reset
    //-------------------------------------------------------------------------
    half3 rectified = lerp(filtered, warped_history, theta * history_onscreen);

    //-------------------------------------------------------------------------
    // 5) Accumulate the new sample in tonemapped space
    //-------------------------------------------------------------------------
    half3 blended = lerp(Tonemap(rectified), Tonemap(own_sample.rgb), alpha * own_sample.a);

    //-------------------------------------------------------------------------
    // 6) Inverse tonemap, undo exposure, and write the output
    //-------------------------------------------------------------------------
    half3 linear_colour = InverseTonemap(blended) * _InvExposure;
    WriteUpsampledColour(out_px, linear_colour);
}