Merge #59
59: Added `MulAdd` and `MulAddAssign` traits r=cuviper a=regexident Both `f32` and `f64` implement fused multiply-add, which computes `(self * a) + b` with only one rounding error. This produces a more accurate result with better performance than a separate multiplication operation followed by an add: ```rust fn mul_add(self, a: f32, b: f32) -> f32 ``` It is, however, not possible to make use of this in a generic context by abstracting over a trait. My concrete use-case is machine learning, [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) to be specific, where the core operation of updating the gradient could make use of `mul_add` for both its `weights: Vector` and its `bias: f32`: ```rust struct Perceptron { weights: Vector, bias: f32, } impl MulAdd<f32, Self> for Vector { // ... } impl Perceptron { fn learn(&mut self, example: Vector, expected: f32, learning_rate: f32) { let alpha = self.error(example, expected, learning_rate); self.weights = example.mul_add(alpha, self.weights); self.bias = self.bias.mul_add(alpha, self.bias) } } ``` (The actual impl of `Vector` would be generic over its value type: `Vector<T>`, thus requiring the trait.) Co-authored-by: Vincent Esche <regexident@gmail.com> Co-authored-by: Josh Stone <cuviper@gmail.com>
This commit is contained in:
commit
a49013e338
|
@ -37,6 +37,7 @@ pub use ops::inv::Inv;
|
||||||
pub use ops::checked::{CheckedAdd, CheckedSub, CheckedMul, CheckedDiv,
|
pub use ops::checked::{CheckedAdd, CheckedSub, CheckedMul, CheckedDiv,
|
||||||
CheckedRem, CheckedNeg, CheckedShl, CheckedShr};
|
CheckedRem, CheckedNeg, CheckedShl, CheckedShr};
|
||||||
pub use ops::wrapping::{WrappingAdd, WrappingMul, WrappingSub};
|
pub use ops::wrapping::{WrappingAdd, WrappingMul, WrappingSub};
|
||||||
|
pub use ops::mul_add::{MulAdd, MulAddAssign};
|
||||||
pub use ops::saturating::Saturating;
|
pub use ops::saturating::Saturating;
|
||||||
pub use sign::{Signed, Unsigned, abs, abs_sub, signum};
|
pub use sign::{Signed, Unsigned, abs, abs_sub, signum};
|
||||||
pub use cast::{AsPrimitive, FromPrimitive, ToPrimitive, NumCast, cast};
|
pub use cast::{AsPrimitive, FromPrimitive, ToPrimitive, NumCast, cast};
|
||||||
|
|
|
@ -2,3 +2,4 @@ pub mod saturating;
|
||||||
pub mod checked;
|
pub mod checked;
|
||||||
pub mod wrapping;
|
pub mod wrapping;
|
||||||
pub mod inv;
|
pub mod inv;
|
||||||
|
pub mod mul_add;
|
||||||
|
|
|
@ -0,0 +1,146 @@
|
||||||
|
/// The fused multiply-add operation, `(self * a) + b`.
|
||||||
|
/// Computes `(self * a) + b` with only one rounding error.
|
||||||
|
/// This produces a more accurate result, and can perform better on targets
|
||||||
|
/// with hardware FMA support, than a separate multiplication operation followed by an add.
|
||||||
|
///
|
||||||
|
/// Note that `A` and `B` are `Self` by default, but this is not mandatory.
/// The integer implementations in this module are not fused: they evaluate
/// `(self * a) + b` with the ordinary `*` and `+` operators.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use std::f32;
|
||||||
|
///
|
||||||
|
/// let m = 10.0_f32;
|
||||||
|
/// let x = 4.0_f32;
|
||||||
|
/// let b = 60.0_f32;
|
||||||
|
///
|
||||||
|
/// // m * x + b is exactly 100.0
|
||||||
|
/// // With only `std::f32` imported, this call resolves to the inherent `f32::mul_add`.
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
|
||||||
|
///
|
||||||
|
/// assert!(abs_difference <= f32::EPSILON);
|
||||||
|
/// ```
|
||||||
|
pub trait MulAdd<A = Self, B = Self> {
|
||||||
|
/// The resulting type after applying the fused multiply-add.
|
||||||
|
type Output;
|
||||||
|
|
||||||
|
/// Performs the fused multiply-add operation, returning `(self * a) + b`.
|
||||||
|
fn mul_add(self, a: A, b: B) -> Self::Output;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The fused multiply-add assignment operation, `*self = (*self * a) + b`.
///
/// As with the value-returning form, `A` and `B` default to `Self`.
|
||||||
|
pub trait MulAddAssign<A = Self, B = Self> {
|
||||||
|
/// Performs the fused multiply-add operation in place, storing the result in `self`.
|
||||||
|
fn mul_add_assign(&mut self, a: A, b: B);
|
||||||
|
}
|
||||||
|
|
||||||
|
// `MulAdd` for `f32`; compiled only when the `std` feature is enabled.
// NOTE(review): presumably gated because the inherent `f32::mul_add` lives in `std` — confirm.
#[cfg(feature = "std")]
|
||||||
|
impl MulAdd<f32, f32> for f32 {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mul_add(self, a: Self, b: Self) -> Self::Output {
|
||||||
|
// Calls the inherent `f32::mul_add`; inherent methods take precedence over this trait method.
f32::mul_add(self, a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `MulAdd` for `f64`; compiled only when the `std` feature is enabled.
// NOTE(review): presumably gated because the inherent `f64::mul_add` lives in `std` — confirm.
#[cfg(feature = "std")]
|
||||||
|
impl MulAdd<f64, f64> for f64 {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mul_add(self, a: Self, b: Self) -> Self::Output {
|
||||||
|
// Calls the inherent `f64::mul_add`; inherent methods take precedence over this trait method.
f64::mul_add(self, a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates an impl of `$trait_name` (invoked below with `MulAdd`) for every listed type,
// with `Output = Self` and both operands of type `Self`.
macro_rules! mul_add_impl {
|
||||||
|
($trait_name:ident for $($t:ty)*) => {$(
|
||||||
|
impl $trait_name for $t {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn mul_add(self, a: Self, b: Self) -> Self::Output {
|
||||||
|
// Not fused: an ordinary multiplication followed by an ordinary addition,
// so overflow behaves exactly as it does for `*` and `+` on the type.
(self * a) + b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)*}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The primitive integer types get the unfused implementation.
mul_add_impl!(MulAdd for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
|
||||||
|
|
||||||
|
// `MulAddAssign` for `f32`; compiled only when the `std` feature is enabled.
#[cfg(feature = "std")]
|
||||||
|
impl MulAddAssign<f32, f32> for f32 {
|
||||||
|
#[inline]
|
||||||
|
fn mul_add_assign(&mut self, a: Self, b: Self) {
|
||||||
|
// Overwrite `self` with the result of `f32::mul_add` applied to its current value.
*self = f32::mul_add(*self, a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `MulAddAssign` for `f64`; compiled only when the `std` feature is enabled.
#[cfg(feature = "std")]
|
||||||
|
impl MulAddAssign<f64, f64> for f64 {
|
||||||
|
#[inline]
|
||||||
|
fn mul_add_assign(&mut self, a: Self, b: Self) {
|
||||||
|
// Overwrite `self` with the result of `f64::mul_add` applied to its current value.
*self = f64::mul_add(*self, a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates an impl of `$trait_name` (invoked below with `MulAddAssign`) for every listed type,
// with both operands of type `Self`.
macro_rules! mul_add_assign_impl {
|
||||||
|
($trait_name:ident for $($t:ty)*) => {$(
|
||||||
|
impl $trait_name for $t {
|
||||||
|
#[inline]
|
||||||
|
fn mul_add_assign(&mut self, a: Self, b: Self) {
|
||||||
|
// Not fused: ordinary `*` then `+`, written back into `self`;
// overflow behaves exactly as it does for those operators on the type.
*self = (*self * a) + b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)*}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The primitive integer types get the unfused in-place implementation.
mul_add_assign_impl!(MulAddAssign for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
|
||||||
|
|
||||||
|
// Unit tests for the `MulAdd` impls in this module.
// NOTE(review): nothing here exercises `MulAddAssign` — consider adding coverage.
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
// Checks `MulAdd::mul_add` against the plain `m*x + b` expression for every listed integer type.
fn mul_add_integer() {
|
||||||
|
// Expands to one scoped check per type identifier passed in.
macro_rules! test_mul_add {
|
||||||
|
($($t:ident)+) => {
|
||||||
|
$(
|
||||||
|
{
|
||||||
|
let m: $t = 2;
|
||||||
|
let x: $t = 3;
|
||||||
|
let b: $t = 4;
|
||||||
|
|
||||||
|
// Called through the trait path so the `MulAdd` impl is what gets exercised.
assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
|
||||||
|
}
|
||||||
|
)+
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
test_mul_add!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
// Only built with the `std` feature, the same gate used on the float impls in this file.
#[cfg(feature = "std")]
|
||||||
|
// Checks that the trait result for `f32`/`f64` stays within a small tolerance of `m*x + b`.
fn mul_add_float() {
|
||||||
|
// Expands to one scoped check per float type identifier passed in.
macro_rules! test_mul_add {
|
||||||
|
($($t:ident)+) => {
|
||||||
|
$(
|
||||||
|
{
|
||||||
|
// Imports the `core::f32` / `core::f64` module so `$t::EPSILON` below can resolve to its constant.
use core::$t;
|
||||||
|
|
||||||
|
let m: $t = 12.0;
|
||||||
|
let x: $t = 3.4;
|
||||||
|
let b: $t = 5.6;
|
||||||
|
|
||||||
|
// Absolute gap between the trait result and the separately rounded multiply-then-add.
let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();
|
||||||
|
|
||||||
|
// NOTE(review): the tolerance is the type's EPSILON, not scaled to the result's
// magnitude (about 46.4 here) — confirm this bound is intended for other inputs.
assert!(abs_difference <= $t::EPSILON);
|
||||||
|
}
|
||||||
|
)+
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
test_mul_add!(f32 f64);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue