Approximating Hyperbolic Tangent

Original link: https://jtomschroeder.com/blog/approximating-tanh/

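## Fast `tanh` Approximations: Summary

The hyperbolic tangent function (`tanh`) is essential in neural networks and audio processing, but it can be expensive to compute. This survey explores several ways to approximate `tanh` for speed. **Traditional methods** include the **Taylor series**, which is simple but of limited accuracy, and the **Padé approximant**, which improves accuracy with a ratio of polynomials at the cost of more operations. **Splines** offer another option, dividing the function into piecewise polynomials and favoring speed over maximum precision. More advanced techniques exploit the underlying **IEEE-754 floating-point representation**. **K-TanH** uses integer operations and a lookup table for an efficient approximation that is particularly well suited to hardware implementations. **Schraudolph's method** manipulates the bit representation of floating-point numbers to approximate `tanh` with minimal arithmetic, a technique similar to the fast inverse square root hack. A later refinement, **Schraudolph-NG**, improves accuracy further by using a specific formula that exploits error cancellation. The evaluation shows different trade-offs between speed and accuracy. While standard library implementations provide precision, these approximations offer significant performance gains, which makes them valuable for resource-constrained applications. The choice of method depends on the specific requirements of the task, balancing the required accuracy against computational efficiency.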


Original article

Survey of fast tanh approximations using Taylor series, Padé approximants, splines, and bitwise manipulation techniques like K-TanH and Schraudolph

The hyperbolic tangent function, \( tanh \), maps any real number to the range (-1, 1) with a smooth S-shaped curve. This property is useful as an activation function in neural networks, where it introduces non-linearity while keeping outputs bounded, and in audio signal processing, where it provides natural-sounding soft clipping for saturation and distortion effects.

In both contexts, speed matters. Neural network inference may evaluate \( tanh \) millions of times per forward pass, and audio processing demands real-time performance at sample rates of 44.1 kHz or higher. The accuracy provided by standard library implementations requires more computation than a tailored approximation.

This post surveys several approaches for approximating \( tanh \): traditional polynomial methods like Taylor series, Padé approximants, and splines, as well as more exotic techniques that exploit the IEEE-754 floating-point representation to achieve impressive speed without too much work.

Survey of Approximations

First, let's take a tour of some of the more common approaches to implementing fast approximations of functions like \( tanh \):

Taylor series

As you might recall from a distant Calculus class, a Taylor series expands a function as an infinite sum of polynomial terms: each term is a successive power of \( x \) scaled by a coefficient derived from the function's successive derivatives. By taking the first few terms of the series, we can find a rough approximation with relatively few operations.

pub fn tanhf(x: f32) -> f32 {
    // Snap to 1 when polynomial begins to deviate at the tails
    if x.abs() > 1.365 {
        return 1f32.copysign(x);
    }

    let t1 = x;
    let t2 = x.powi(3) * (1. / 3.);
    let t3 = x.powi(5) * (2. / 15.);
    let t4 = x.powi(7) * (17. / 315.);
    let t5 = x.powi(9) * (62. / 2835.);
    let t6 = x.powi(11) * (1382. / 155925.);

    t1 - t2 + t3 - t4 + t5 - t6
}
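
As a side note, the same polynomial can be evaluated with Horner's rule, which replaces the repeated `powi` calls with one multiply-add per coefficient. The sketch below is an illustration rather than part of the original post; the name `tanhf_horner` is made up here, and the coefficients and cutoff are copied from the function above.

```rust
/// Horner-form evaluation of the same 11th-degree Taylor polynomial.
/// `tanhf_horner` is a hypothetical name used only for this sketch.
pub fn tanhf_horner(x: f32) -> f32 {
    // Same cutoff as above: snap to +/-1 where the polynomial starts to deviate.
    if x.abs() > 1.365 {
        return 1f32.copysign(x);
    }

    // Coefficients of x^1, x^3, ..., x^11, with their signs included.
    const C: [f32; 6] = [
        1.0,
        -1.0 / 3.0,
        2.0 / 15.0,
        -17.0 / 315.0,
        62.0 / 2835.0,
        -1382.0 / 155925.0,
    ];

    let x2 = x * x;
    // Evaluate c0 + x2 * (c1 + x2 * (c2 + ...)) from the innermost term outward.
    let mut acc = C[5];
    for &c in C[..5].iter().rev() {
        acc = acc * x2 + c;
    }
    x * acc
}
```

Both forms compute the same polynomial; whether the Horner form is actually faster depends on the compiler and target, so it is worth benchmarking rather than assuming.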

Padé Approximant

Like a truncated Taylor series, a Padé approximant is built from polynomials, but it is a ratio of two of them (i.e. one polynomial divided by another). This usually gives more accuracy for the same degree, but requires more operations (including a division).

Below is a Rust adaptation of the tanh approximation in JUCE's FastMathApproximations. In mathematical parlance, this appears to be a "[7/6] Padé approximant," with a 7th-degree numerator and a 6th-degree denominator.

pub fn tanhf(x: f32) -> f32 {
    // This approximation works on a limited range.
    // Use input values only between -5 and +5 for limiting the error.
    if x.abs() > 5. {
        return 1f32.copysign(x);
    }

    let x2 = x * x;
    let numerator = x * (135135. + x2 * (17325. + x2 * (378. + x2)));
    let denominator = 135135. + x2 * (62370. + x2 * (3150. + 28. * x2));
    numerator / denominator
}
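
One way to see where these coefficients come from: truncating Lambert's continued fraction for \( tanh \) after the partial denominator 13 and simplifying yields exactly the rational function above. The sketch below is not from the original post; `tanh_cf` and `tanh_pade` are names chosen for illustration, and the range clamp is left out to keep the comparison direct.

```rust
/// Lambert's continued fraction for tanh, truncated after the partial denominator 13:
/// tanh(x) ~ x / (1 + x^2 / (3 + x^2 / (5 + ... + x^2 / 13)))
pub fn tanh_cf(x: f32) -> f32 {
    let x2 = x * x;
    // Start at the innermost denominator and work outward.
    let mut t = 13.0_f32;
    for k in [11.0_f32, 9.0, 7.0, 5.0, 3.0, 1.0] {
        t = k + x2 / t;
    }
    x / t
}

/// The rational form used in the example above (without the range clamp).
pub fn tanh_pade(x: f32) -> f32 {
    let x2 = x * x;
    let numerator = x * (135135. + x2 * (17325. + x2 * (378. + x2)));
    let denominator = 135135. + x2 * (62370. + x2 * (3150. + 28. * x2));
    numerator / denominator
}
```

The continued-fraction form needs seven divisions, so it is only useful for checking the coefficients; the expanded rational form with a single division is the one to use for speed.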

Splines

Next, we have splines: piece-wise polynomial functions, which can be used to approximate a function by splitting it into many pieces (or subintervals) and calculating the coefficients for a (cubic) polynomial for each piece.

The example below is borrowed from "Efficiently inaccurate approximation of hyperbolic tangent used as transfer function in artificial neural networks" (Simos, Tsitouras), which examines the use of splines in approximating \( tanh \) for use as a transfer function in neural networks. The paper mentions that for its use, speed is the main concern, and accuracy is a lesser issue.

In tanhf3, the range [0, 18] is divided into three subintervals, and each is given its own third-degree polynomial. The specific coefficients are provided in the paper and were likely found using software like MATLAB.

pub fn tanhf3(xin: f32) -> f32 {
    const N1: f32 = 0.371025186672900;
    const N2: f32 = 2.572153900248530;
    const N3: f32 = 18.;

    // Work on |x| and restore the sign at the end (tanh is an odd function)
    match xin.abs() {
        x if x < N1 => {
            -3.695076086125492e-1 * x.powi(3)
                + 1.987219343897867e-2 * x.powi(2)
                + x
        }
        x if x < N2 => {
            let n = x - N1;
            5.928356367224758e-2 * n.powi(3)
                - 3.914176949486042e-1 * n.powi(2)
                + 8.621472609449146e-1 * n
                + 3.548881072496229e-1
        }
        x if x < N3 => {
            let n = x - N2;
            -3.347599023061577e-6 * n.powi(3)
                + 5.456777761558641e-5 * n.powi(2)
                + 7.066442941005233e-4 * n
                + 9.884026213740197e-1
        }
        // Beyond the last knot, tanh is indistinguishable from 1
        _ => 1.,
    }.copysign(xin)
}

Approximating By Format

Now that we've surveyed a few mathematical approaches to approximation, let's take a look at how we can use the established format of floating-point numbers (IEEE-754) to approximate \( tanh \).

Below is a diagram of the binary representation of a 32-bit floating-point number from Wikipedia. For the value of 0.15625, we can see its individual components: sign, exponent, and fraction (also known as 'mantissa').

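[Figure: binary layout of the 32-bit float 0.15625, showing the sign, exponent, and fraction fields (from Wikipedia)]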

With the binary form, the nominal value can be derived with the following equation:

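$$ (s, E, M) = (-1)^s \cdot 2^{E - E_{bias}} \cdot (1 + M/2^p), $$

where \( s, E, M, p \) are the sign, (bias-added) exponent, mantissa, and number of mantissa bits, all of which are non-negative integers, and \( E_{bias} \) is the exponent bias (127 for 32-bit floats, where \( p = 23 \)).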
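
To make the fields concrete, here is a minimal sketch (not from the original post) that splits an `f32` into its three fields and rebuilds the value from them. The helper names `decompose` and `recompose` are made up for this example, and it assumes a normal (non-zero, non-subnormal) number with the stored exponent treated as biased, so the bias of 127 is subtracted when rebuilding.

```rust
/// Split an f32 into its (sign, biased exponent, mantissa) bit fields.
pub fn decompose(x: f32) -> (u32, u32, u32) {
    let bits = x.to_bits();
    let s = bits >> 31; // 1 sign bit
    let e = (bits >> 23) & 0xFF; // 8 exponent bits (bias included)
    let m = bits & 0x7F_FFFF; // 23 mantissa bits
    (s, e, m)
}

/// Rebuild the value of a normal number from its fields.
pub fn recompose(s: u32, e: u32, m: u32) -> f32 {
    const P: u32 = 23; // number of mantissa bits
    const BIAS: i32 = 127; // exponent bias for f32
    let sign = if s == 0 { 1.0 } else { -1.0 };
    sign * 2f32.powi(e as i32 - BIAS) * (1.0 + m as f32 / (1u32 << P) as f32)
}
```

For 0.15625 the fields are sign 0, exponent 124, and mantissa `0x200000`, i.e. \( 2^{124-127} \cdot 1.25 \).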

K-TanH

The paper "K-TanH: Efficient TanH For Deep Learning" proposes a hardware-efficient algorithm for approximating \( tanh \) using only integer operations and a small (512-bit) lookup table.

The paper provides the following pseudocode for the algorithm in Algorithm 1:

  1. Input:
    1. Input \( x_i = (s_i,E_i,M_i) \)
    2. Parameter Tables \( T_E,T_r,T_b \)
  2. Output:
    1. Output \( y_o = (s_o,E_o,M_o) \)
  3. If \( |x_i| < T_1 \), \( y_o \leftarrow x_i \)
    1. i.e., \( (s_o, E_o, M_o) = (s_i,E_i,M_i) \)
  4. Else If \( |x_i| > T_2 \), \( y_o \leftarrow s_i \cdot 1 \)
    1. i.e., \( (s_o, E_o, M_o) = (s_i,E_{bias},0) \)
  5. Else,
    1. Form bit string \( t \) using lower bits of \( E_i \) and higher bits of \( M_i \).
    2. Fetch parameters \( \theta_t = (E_t,r_t,b_t) \) from \( T_E, T_r, T_b \) using index \( t \).
    3. \( s_o \leftarrow s_i, E_o \leftarrow E_t, M_o \leftarrow (M_i \gg r_t) + b_t \)
  6. Return \( y_o \)

Transforming the pseudocode above into Rust gives us the following:

    pub fn tanhf(x: f32) -> f32 {
        const T1: f32 = 0.25;
        const T2: f32 = 3.75;

        // `tanh` is an odd function, so the thresholds apply to the magnitude
        let xa = x.abs();

        if xa < T1 {
            // Small inputs: tanh(x) is close to x
            x
        } else if xa > T2 {
            // Large inputs: saturate to +/-1
            1f32.copysign(x)
        } else {
            let x: u32 = x.to_bits();

            // Upper 7 mantissa bits and the sign bit of the input
            let mi = (x >> 16) & 0b0111_1111;
            let so = x & 0x8000_0000;

            // t = bit string using lower bits of `E` and higher bits of `M`
            let t = (x >> 20) & 0b11_111;
            let (et, rt, bt) = unpack(LUT[t as usize]);

            // Output exponent comes from the table; mantissa is shifted and biased
            let eo = (et as u32) << 23;
            let mo = (((mi >> (rt as u32)) as i32 + bt as i32) as u32) << 16;

            f32::from_bits(so | eo | mo)
        }
    }
    

With the general algorithm sorted out, the remaining piece is constructing the lookup table (LUT). The values of the lookup table are defined in Table 1 of the paper, and shown below in Rust. To fit the table into 512 bits, as mentioned in the paper, I implemented pack and unpack functions to combine and decompose (respectively) the components of each entry in the table.

The values in the lookup table are derived through an offline process that finds optimal values for the exponent (\( E_t \)), right-shift amount (\( r_t \)), and bias (\( b_t \)). The process is described in Section 2.4 of the paper:

"Finding the optimized tables is one time offline compute process. Also, we want to fit each table in a register of a general purpose machine for quick access. E.g., to fit each table in a 512-bit register for Intel AVX512 SIMD instructions, we use 5-bit indexing (2 LSBs of exponent and 3 MSBs of mantissa) to create 32 entries (32 intervals of the input magnitude), each holding up to 16 bit integer values. Our parameter values are 8-bit only, so we can create 64 intervals to achieve more accurate approximation. However, experimentally, 32 entries suffices."

    const LUT: [u16; 32] = [
        pack(126, 2, 119),
        pack(126, 4, 122),
        pack(126, 4, 123),
        pack(126, 4, 123),
        pack(126, 6, 126),
        pack(126, 6, 126),
        pack(126, 6, 126),
        pack(126, 6, 126),
        pack(125, 1, 1),
        pack(125, 0, -4),
        pack(125, 0, -6),
        pack(125, 0, -7),
        pack(125, 0, -10),
        pack(125, 0, -12),
        pack(125, 0, -15),
        pack(125, 0, -18),
        pack(125, 0, 112),
        pack(126, 1, -4),
        pack(126, 1, -1),
        pack(126, 1, 2),
        pack(126, 1, 3),
        pack(126, 1, 4),
        pack(126, 1, 4),
        pack(126, 1, 4),
        pack(126, 0, 65),
        pack(126, 1, 72),
        pack(126, 1, 73),
        pack(126, 1, 73),
        pack(126, 2, 88),
        pack(126, 2, 89),
        pack(126, 2, 89),
        pack(126, 4, 110),
    ];
    
    #[inline]
    const fn pack(et: u8, rt: u8, bt: i8) -> u16 {
        let e = match et {
            125 => 0,
            126 => 1,
            _ => panic!("Invalid value for `et`"),
        };
    
        // Layout: 4 bits for the exponent selector, 4 bits for `rt`, 8 bits for `bt`
        ((((e << 4) | rt) as u16) << 8) | (bt as u8 as u16)
    }
    
    #[inline]
    const fn unpack(x: u16) -> (u8, u8, i8) {
        let et = (x >> 12) as u8 + 125;
        let rt = (x >> 8) as u8 & 0xF;
        let bt = x as u8 as i8;
    
        (et, rt, bt)
    }
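
The ordering of the table can look odd at first, since it does not follow increasing input magnitude. A minimal sketch of the index computation (the helper name `lut_index` is hypothetical, and it only mirrors the `(x >> 20) & 0b11_111` step in `tanhf` above) shows why: the two low exponent bits wrap around, so entries 0 to 7 cover [2, 4), entries 8 to 15 cover [0.25, 0.5), entries 16 to 23 cover [0.5, 1), and entries 24 to 31 cover [1, 2).

```rust
/// 5-bit K-TanH table index: the 2 least significant exponent bits
/// followed by the 3 most significant mantissa bits.
/// Only meaningful for magnitudes between the thresholds 0.25 and 3.75.
pub fn lut_index(x: f32) -> usize {
    ((x.to_bits() >> 20) & 0b1_1111) as usize
}
```

For example, `lut_index(1.0)` is 24 because the biased exponent 127 ends in the bits `11` and the mantissa is zero, while `lut_index(2.0)` wraps back to 0 because the exponent 128 ends in `00`.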
    

The K-TanH algorithm is particularly well-suited for hardware implementations and SIMD parallelism. The paper demonstrates using Intel AVX512 instructions to fit the entire lookup table in a single 512-bit register, enabling extremely fast lookups. The algorithm would also benefit from the 'bfloat16' (brain floating-point) format, which uses fewer mantissa bits, reducing the table size while maintaining sufficient precision for workloads such as deep learning.

Schraudolph

In 1999, Nicol Schraudolph published "A Fast, Compact Approximation of the Exponential Function", introducing a clever technique that exploits the IEEE-754 floating-point format to approximate \( e^x \) with just a few integer operations. The approach is similar in spirit to the famous "fast inverse square root" hack from Quake III Arena: both treat the bit representation of a float as an integer to perform approximate calculations.

The key insight is that the binary representation of a floating-point number already encodes something like a logarithm in its exponent field. By carefully manipulating these bits, we can approximate exponential functions with a single floating-point multiplication followed by integer operations.
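
The logarithm-like behaviour is easy to check directly. The sketch below is not from the original post (the name `approx_log2` is made up for illustration): it reads the bits of a positive, normal `f32` as an integer, rescales by \( 2^{23} \), and subtracts the bias of 127. The result is exact at powers of two and a piecewise-linear interpolation of \( \log_2 \) in between.

```rust
/// Approximate log2(x) for positive, normal x by reading the bit pattern
/// as an integer: the exponent field supplies the integer part and the
/// mantissa a linear interpolation of the fractional part.
pub fn approx_log2(x: f32) -> f32 {
    x.to_bits() as f32 / (1u32 << 23) as f32 - 127.0
}
```

Schraudolph's method runs this relationship in reverse: it builds an integer that is a linear function of the input and reinterprets it as a float, which gives a piecewise-linear approximation of the exponential.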

Here's Schraudolph's original implementation for doubles, in C:

    #include <math.h> /* M_LN2 */

    #define EXP_A (1048576 / M_LN2)
    #define EXP_C 60801
    static inline double exponential(double y) {
        union {
            double d;
        #ifdef LITTLE_ENDIAN
            struct { int j, i; } n;
        #else
            struct { int i, j; } n;
        #endif
        } _eco;
        /* i is the upper 32 bits of the double, j the lower 32 bits */
        _eco.n.i = (int)(EXP_A * y) + (1072693248 - EXP_C);
        _eco.n.j = 0;
        return _eco.d;
    }
    

After reverse-engineering the constants, transforming from double to float (and porting to Rust), we get:

    pub fn expf(y: f32) -> f32 {
        use core::f32::consts::LN_2;

        const BIAS: i16 = f32::MAX_EXP as i16 - 1;
        const MANTISSA_BITS: i16 = f32::MANTISSA_DIGITS as i16 - 1;
        const OFFSET_BITS: i16 = i16::BITS as i16;

        // Only the upper 16 bits of the float are written, leaving 7 mantissa bits
        const X: i16 = 1 << (MANTISSA_BITS - OFFSET_BITS);
        const A: f32 = X as f32 / LN_2;
        const B: i16 = X * BIAS;
        const C: i16 = 8; // tuning parameter
        const D: i16 = B - C;

        // The caller must keep `y` small enough that `A * y + D` fits in an i16
        unsafe {
            let y = (A * y).to_int_unchecked::<i16>() + D;

            // Place `y` in the upper half of the 32-bit pattern
            core::mem::transmute(
                #[cfg(target_endian = "little")] [0, y],
                #[cfg(target_endian = "big")] [y, 0],
            )
        }
    }
    

And then, we can express \( tanh \) in terms of expf, i.e. \( tanh(x) = (e^{2x} - 1)/(e^{2x} + 1) \):

    /// Approximating `tanh` via `exp`
    pub fn tanhf(x: f32) -> f32 {
        let y = expf(2. * x);
        (y - 1.) / (y + 1.)
    }
    

The algorithm works by treating the floating-point bit pattern as an integer. For doubles, Schraudolph splits the 64 bits into two 32-bit integers and operates on just the upper half (which contains the sign, exponent, and upper mantissa bits). The core formula is:

$$ i = ay + (b - c) $$

where \( a = 2^{20}/\ln(2) \), \( b = 1023 \cdot 2^{20} \) (encoding the exponent bias), and \( c \) is a tuning parameter that shifts the approximation to trade one error measure against another; it has no effect on speed. The original paper recommends \( c = 60801 \), which minimizes the root-mean-square relative error.

For 32-bit floats, the upper 16 bits hold only 7 mantissa bits, so we adjust the constants: \( a = 2^7/\ln(2) \), \( b = 127 \cdot 2^7 \), and \( c = 8 \). The Rust implementation above derives these values at compile-time for clarity.
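
To see the 32-bit constants at work without `unsafe`, here is a minimal sketch with the numbers written out. It is an illustration under assumptions, not the post's implementation: the name `exp_sketch` is made up, it uses a plain `as` cast and `f32::from_bits` instead of `transmute`, and it is only meaningful while the computed pattern stays a positive normal float.

```rust
/// Schraudolph-style exp for f32 with the constants written out.
/// `exp_sketch` is a hypothetical name used only for this example.
pub fn exp_sketch(y: f32) -> f32 {
    const A: f32 = 128.0 / std::f32::consts::LN_2; // a = 2^7 / ln(2)
    const B: i32 = 127 * 128; // b = bias * 2^7 = 16256
    const C: i32 = 8; // c, the tuning shift

    // i = a*y + (b - c): the upper 16 bits of the result
    // (sign, exponent, and the top 7 mantissa bits).
    let i = (A * y) as i32 + (B - C);
    // Outside this range the pattern is no longer a positive normal float.
    debug_assert!(i > 0 && i < 0x7F80);
    f32::from_bits((i as u32) << 16)
}
```

Worked by hand, `exp_sketch(0.0)` gives \( i = 16248 \), i.e. exponent field 126 and mantissa 120/128, which decodes to 0.96875; `exp_sketch(1.0)` gives \( i = 16432 \), which decodes to 2.75 against a true value of about 2.718. The relative error stays within roughly 5% over the valid range.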

Schraudolph-NG: Improving Accuracy

In 2018, Schraudolph himself proposed an improved version in a Stack Overflow comment:

"A good increase in accuracy in my algorithm [...] can be obtained at the cost of an integer subtraction and floating-point division by using FastExpSse(x/2)/FastExpSse(-x/2) instead of FastExpSse(x). The trick here is to set the shift parameter to zero so that the piecewise linear approximations in the numerator and denominator line up to give you substantial error cancellation."

This variant exploits the identity:

$$ e^x = \frac{e^{x/2}}{e^{-x/2}} $$

When both the numerator and denominator use the same piecewise-linear approximation (with zero shift), their errors are correlated. Dividing them causes much of the error to cancel out, yielding a more accurate result at the cost of one extra exp evaluation and a division.
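
The cancellation can be sketched with a short calculation (a back-of-the-envelope derivation added here, ignoring integer truncation). Let \( L(y) \) be the zero-shift approximation of \( e^y \), and write \( x/(2\ln 2) = k + f \) with integer \( k \) and \( 0 \le f < 1 \). Reading the exponent and mantissa off the bit pattern gives:

$$ L(x/2) = 2^{k}(1 + f), \qquad L(-x/2) = 2^{-k-1}(2 - f) $$

Since \( e^x = 2^{2k} \cdot 4^{f} \), the ratio of the approximation to the true value is:

$$ \frac{L(x/2)/L(-x/2)}{e^x} = \frac{2(1 + f)}{(2 - f) \cdot 4^{f}} $$

This ratio is exactly 1 at \( f = 0 \), \( f = 1/2 \), and \( f = 1 \), and stays within about 1% of 1 in between. A single unshifted evaluation has the ratio \( (1 + f)/2^{f} \), which deviates by up to about 6%.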

The implementation below, dubbed NG for 'next generation', computes expf using this approach and includes an optional NEON-optimized path for ARM processors:

    use std::f32::consts::*;
    
    #[cfg(target_endian = "little")]
    pub fn expf(x: f32) -> f32 {
        const BIAS: u32 = f32::MAX_EXP as u32 - 1;
        const MANTISSA_BITS: u32 = f32::MANTISSA_DIGITS - 1;
    
        const A: f32 = (1 << MANTISSA_BITS) as f32 / LN_2;
        const B: f32 = (BIAS << MANTISSA_BITS) as f32;
    
        debug_assert!((A * x).abs() < B); // |x| below roughly 88, so e^x stays within f32 range

        #[cfg(not(target_feature = "neon"))]
        {
            f32::from_bits((A / 2. * x + B) as u32) / f32::from_bits((-A / 2. * x + B) as u32)
        }
    
        #[cfg(target_feature = "neon")]
        unsafe {
            use std::arch::aarch64::*;
            use std::mem::transmute;
    
            let a = vcreate_f32(transmute([A / 2., -A / 2.]));
            let x = vdup_n_f32(x);
            let b = vdup_n_f32(B);
    
            let y = vfma_f32(b, a, x);
            let y = vcvt_u32_f32(y);
            let y = vreinterpret_f32_u32(y);
    
            vget_lane_f32::<0>(y) / vget_lane_f32::<1>(y)
        }
    }
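
For experimenting on any target, the scalar path can be reduced to a few lines and combined with the \( tanh \) identity used earlier. This is a sketch under assumptions rather than the post's code: the names `exp_ng` and `tanh_ng` are made up here, the NEON path and the endianness gate are omitted, and inputs are assumed to stay within roughly \( |x| < 88 \).

```rust
use std::f32::consts::LN_2;

/// Scalar sketch of the ratio form: exp(x) ~ L(x/2) / L(-x/2),
/// where L is the zero-shift bit-pattern approximation.
pub fn exp_ng(x: f32) -> f32 {
    const A: f32 = (1u32 << 23) as f32 / LN_2; // 2^23 / ln(2)
    const B: f32 = (127u32 << 23) as f32; // exponent bias, in exponent position

    let num = f32::from_bits((A / 2.0 * x + B) as u32);
    let den = f32::from_bits((-A / 2.0 * x + B) as u32);
    num / den
}

/// tanh(x) = (e^{2x} - 1) / (e^{2x} + 1), built on the sketch above.
pub fn tanh_ng(x: f32) -> f32 {
    let y = exp_ng(2.0 * x);
    (y - 1.0) / (y + 1.0)
}
```

With zero shift, `exp_ng(0.0)` is exactly 1, so `tanh_ng(0.0)` is exactly 0, which is a convenient property for an activation function.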
    

Results

The plot below shows the approximation error (compared to the standard library tanhf) for each method across the input range. Each approximation was evaluated over a uniform distribution of inputs, measuring both maximum absolute error and average error.

[Figure: approximation error of each method, relative to the standard library tanhf, across the input range]
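
For readers who want to reproduce this kind of measurement, here is a minimal sketch of an error harness. It is not the code used to produce the plot: the name `error_stats` is made up, it samples an evenly spaced grid rather than drawing random inputs, and it uses Rust's `f32::tanh` as the reference.

```rust
/// Maximum and mean absolute error of `approx` against `f32::tanh`,
/// sampled at `n` evenly spaced points in [lo, hi].
pub fn error_stats(approx: impl Fn(f32) -> f32, lo: f32, hi: f32, n: usize) -> (f32, f32) {
    assert!(n >= 2, "need at least two sample points");

    let mut max_err = 0.0_f32;
    let mut sum_err = 0.0_f64; // accumulate in f64 to limit rounding drift
    for i in 0..n {
        let x = lo + (hi - lo) * (i as f32 / (n - 1) as f32);
        let err = (approx(x) - x.tanh()).abs();
        max_err = max_err.max(err);
        sum_err += err as f64;
    }
    (max_err, (sum_err / n as f64) as f32)
}
```

As a sanity check, a hard clip such as `error_stats(|x| x.clamp(-1.0, 1.0), -4.0, 4.0, 8001)` reports a maximum error of about 0.24, reached near \( x = \pm 1 \); any of the approximations in this post can be passed in the same way.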

Thanks

Thanks to David Silverman, my teammate on the school project where this analysis originated. Check out his book, Stop Harming Customers: A Compliance Manifesto.
