temnick's picture
Initial content
f724cf3
//
// -----------------------------------------------------------------------------
// The proprietary software and information contained in this file is
// confidential and may only be used by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
//
// Copyright (C) 2025. Arm Limited or its affiliates. All rights reserved.
//
// This entire notice must be reproduced on all copies of this file and
// copies of this file may only be made by an authorized person under a valid
// licensing agreement from Arm Limited or its affiliates.
// -----------------------------------------------------------------------------
//
#ifndef NSS_COMMON
#define NSS_COMMON
#include "typedefs.h"
// Largest finite IEEE 754 binary16 (half) value; used below as the upper clamp
// when sanitising colours into the representable half range.
#define MAX_FP16 65504.HF
// Small epsilon. NOTE(review): 1e-7 is below the smallest normal half
// (~6.1e-5), so as a half literal it becomes a subnormal (~1.19e-7) and may be
// flushed to zero where fp16 denormals are not preserved. It is not referenced
// elsewhere in this file — confirm the intended magnitude at its use sites.
#define EPS 1e-7HF
// Activation Functions
// ──────────────────────────────────────────────────────────────────────────────────────────
// Logistic sigmoid 1 / (1 + e^-x), one overload per vector width.
// For strongly negative x, exp(-x) overflows half to +inf; assuming rcp(+inf)
// yields 0 (as 1/x does), the result still saturates to the correct limit 0.
half Sigmoid(half x)
{
    half e = exp(-x);
    return rcp(half(1.0) + e);
}
half2 Sigmoid(half2 x)
{
    half2 e = exp(-x);
    return rcp(half2(1.0) + e);
}
half3 Sigmoid(half3 x)
{
    half3 e = exp(-x);
    return rcp(half3(1.0) + e);
}
half4 Sigmoid(half4 x)
{
    half4 e = exp(-x);
    return rcp(half4(1.0) + e);
}
// Quantize/Dequantize
// ──────────────────────────────────────────────────────────────────────────────────────────
// quant_params layout: .x = scale, .y = zero point. Exception: the Quantize methods expect .x = rcp(scale), so they can multiply instead of divide.
// Dequantize: recover the real value as (q - zero_point) * scale, where
// quant_params = (scale, zero_point). One overload per vector width.
half Dequantize(half i, half2 quant_params)
{
    half scale = quant_params.x;
    half zero_point = quant_params.y;
    return (i - zero_point) * scale;
}
half2 Dequantize(half2 i, half2 quant_params)
{
    half scale = quant_params.x;
    half zero_point = quant_params.y;
    return (i - zero_point) * scale;
}
half3 Dequantize(half3 i, half2 quant_params)
{
    half scale = quant_params.x;
    half zero_point = quant_params.y;
    return (i - zero_point) * scale;
}
half4 Dequantize(half4 i, half2 quant_params)
{
    half scale = quant_params.x;
    half zero_point = quant_params.y;
    return (i - zero_point) * scale;
}
// Quantize: map a real value to int8 as clamp(round(f / scale + zero_point), -128, 127).
// quant_params.x must hold rcp(scale) (not scale) so the division becomes a
// multiply; quant_params.y is the zero point. One overload per vector width.
// NOTE(review): GLSL round() leaves the direction for exact .5 fractions to the
// implementation; if bit-exact parity with the training-side quantizer matters,
// confirm whether roundEven() is what is wanted.
int8_t Quantize(half f, half2 quant_params)
{
    half inv_scale = quant_params.x;
    half zero_point = quant_params.y;
    half q = round(f * inv_scale + zero_point);
    return int8_t(clamp(q, -128.HF, 127.HF));
}
int8_t2 Quantize(half2 f, half2 quant_params)
{
    half inv_scale = quant_params.x;
    half zero_point = quant_params.y;
    half2 q = round(f * inv_scale + zero_point);
    return int8_t2(clamp(q, -128.HF, 127.HF));
}
int8_t3 Quantize(half3 f, half2 quant_params)
{
    half inv_scale = quant_params.x;
    half zero_point = quant_params.y;
    half3 q = round(f * inv_scale + zero_point);
    return int8_t3(clamp(q, -128.HF, 127.HF));
}
int8_t4 Quantize(half4 f, half2 quant_params)
{
    half inv_scale = quant_params.x;
    half zero_point = quant_params.y;
    half4 q = round(f * inv_scale + zero_point);
    return int8_t4(clamp(q, -128.HF, 127.HF));
}
// Encode/Decode
// ──────────────────────────────────────────────────────────────────────────────────────────
// Note: both encode/decode methods are currently limited to a 3x3 window; they can be
// extended in future if needed. The most likely candidate is the jitter encoding,
// where a 3x3 window may not be enough for scale factors larger than 3x3.
uint8_t EncodeNearestDepthCoord(int32_t2 o)
{
    // Packs a 3x3-window offset o ∈ {-1, 0, 1}² into 4 bits:
    //   bits 0-1 = o.x + 1, bits 2-3 = o.y + 1
    // i.e. code = (o.y + 1) * 4 + (o.x + 1), which lies in [0, 10]. It fits in
    // a nibble, but codes 3, 7 and 11-15 are never produced.
    // Offsets outside the window are first clamped component-wise into it.
    o = clamp(o, int32_t2(-1), int32_t2(1));
    return uint8_t(((o.y + 1) << 2) | (o.x + 1));
}
// Inverse of EncodeNearestDepthCoord: bits 0-1 hold x + 1 and bits 2-3 hold
// y + 1; all higher bits are ignored. Codes the encoder never emits (a 2-bit
// field equal to 3) decode to a component of 2, i.e. outside {-1, 0, 1}.
int32_t2 DecodeNearestDepthCoord(int32_t code)
{
    int32_t2 biased = int32_t2(code & 0x3, (code >> 2) & 0x3);
    return biased - int32_t2(1);
}
// Image Operations
// ──────────────────────────────────────────────────────────────────────────────────────────
half Luminance(half3 rgb)
{
    // Relative luminance with ITU-R BT.709 weights:
    // Y = 0.2126 R + 0.7152 G + 0.0722 B (the weights sum to 1).
    // NOTE(review): these weights are meant for linear RGB — confirm callers
    // do not pass gamma-encoded values.
    const half3 kBT709Weights = half3(0.2126, 0.7152, 0.0722);
    return dot(rgb, kBT709Weights);
}
half3 Tonemap(half3 x)
{
    // Karis tonemapper: divide every channel by (1 + brightest channel). This
    // compresses [0, +inf) towards [0, 1] and keeps the ratios between channels.
    // http://graphicrants.blogspot.com/2013/12/tone-mapping.html
    // Negative inputs are clamped to 0 first. In true fp16 arithmetic, large
    // inputs can round to exactly 1.0 once the +1 is absorbed by rounding.
    half3 colour = max(x, half3(0.HF));
    half peak = max(max(colour.r, colour.g), colour.b);
    return colour * rcp(half3(1.HF) + peak);
}
half3 InverseTonemap(half3 x)
{
    // Inverse of the Karis tonemapper: y / (1 - max(y.r, y.g, y.b)).
    // http://graphicrants.blogspot.com/2013/12/tone-mapping.html
    //
    // The input is clamped to [0, 1 - 2^-11] so the denominator can never be 0.
    // 1 - 2^-11 (0.99951171875) is the largest half strictly below 1.0. It is
    // exactly what Tonemap(MAX_FP16) evaluates to under IEEE round-to-nearest
    // fp16 arithmetic: 1 + 65504 rounds back to 65504, rcp(65504) rounds to
    // 2^-16, and 65504 * 2^-16 == 1 - 2^-11. The largest value this function
    // can therefore reconstruct is (1 - 2^-11) * 2048 = 2047.
    //
    // The bound is a literal rather than Tonemap(MAX_FP16), because that
    // expression goes through rcp(65504) = 2^-16, which is an fp16 *denormal*.
    // Where fp16 denormals are flushed, the bound collapses to 0 and every
    // output becomes black. An rcp that rounds up can push the bound above 1.0,
    // which makes 1 - peak == 0 possible and the result +inf.
    // NOTE(review): this assumes `half` is a true 16-bit type, as the HF
    // literals suggest. If `half` is evaluated at fp32, the old bound was about
    // 0.99998 and this constant caps the output range lower — confirm.
    const half3 kMaxTonemapped = half3(0.99951171875HF);

    x = clamp(x, half3(0.HF), kMaxTonemapped);
    half peak = max(max(x.r, x.g), x.b);
    return x * rcp(half3(1.HF) - peak);
}
// Sanitise a colour into the finite, non-negative half range [0, MAX_FP16]:
// negatives (including -inf) go to 0, and +inf is pulled back to the largest
// finite half.
// NOTE(review): GLSL leaves min/max (and hence clamp) of NaN undefined, so
// NaNs are not guaranteed to be removed here.
half3 SafeColour(half3 x)
{
    const half3 kLower = half3(0.HF);
    const half3 kUpper = half3(MAX_FP16);
    return clamp(x, kLower, kUpper);
}
#endif // NSS_COMMON