route/vendor/github.com/aclements/go-moremath/stats/util_test.go

76 lines
1.7 KiB
Go
Raw Normal View History

2017-10-06 15:29:20 +00:00
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"fmt"
"testing"
"github.com/aclements/go-moremath/internal/mathtest"
"github.com/aclements/go-moremath/vec"
)
// Package-local shorthands for the shared numeric test helpers in
// internal/mathtest: aeq aliases mathtest.Aeq and testFunc aliases
// mathtest.WantFunc.
var (
	aeq      = mathtest.Aeq
	testFunc = mathtest.WantFunc
)
// testDiscreteCDF checks dist.CDF against a CDF built by summing dist.PMF
// over the points l, l+s, l+2s, ... strictly below h, where [l, h] are
// dist.Bounds() and s is dist.Step().
//
// Expected values are:
//   - 0 just below the lower bound (at l-0.1);
//   - the running PMF sum at each point x and at the midpoint x+s/2, where a
//     step-shaped CDF must still equal its value at x;
//   - 1 at the upper bound h.
//
// Points are computed as l + i*s instead of by repeatedly adding s. This
// keeps floating-point error from accumulating for non-integer steps, which
// would otherwise produce keys that miss the true points, or an extra
// iteration just below h.
func testDiscreteCDF(t *testing.T, name string, dist DiscreteDist) {
	// Build the expected CDF out of the PMF.
	l, h := dist.Bounds()
	s := dist.Step()
	if !(s > 0) {
		// A zero, negative, or NaN step would never advance toward h.
		t.Errorf("%s: Step()=%v, want > 0", name, s)
		return
	}

	want := map[float64]float64{l - 0.1: 0, h: 1}
	sum := 0.0
	for i := 0; ; i++ {
		x := l + float64(i)*s
		if !(x < h) {
			break
		}
		sum += dist.PMF(x)
		want[x] = sum
		want[x+s/2] = sum
	}
	testFunc(t, name, dist.CDF, want)
}
// testInvCDF checks InvCDF(dist) for consistency with dist.CDF.
//
// It verifies that:
//   - the inverse CDF returns NaN for probabilities outside [0, 1];
//   - for an unbounded dist, it maps 0 to -Inf and 1 to +Inf;
//   - for a bounded dist with lo = inv(0) and hi = inv(1), dist.CDF is 0 at
//     and just below lo, 1 at and just above hi, and strictly inside (0, 1)
//     just inside those bounds;
//   - inv(dist.CDF(x)) == x for x = inv(p), at each interior point p of
//     vec.Linspace(0, 1, 11) (presumably 0.1 through 0.9).
func testInvCDF(t *testing.T, dist Dist, bounded bool) {
	inv := InvCDF(dist)
	name := fmt.Sprintf("InvCDF(%+v)", dist)
	cdfName := fmt.Sprintf("CDF(%+v)", dist)

	// Test bounds.
	vals := map[float64]float64{-0.01: nan, 1.01: nan}
	if !bounded {
		vals[0] = -inf
		vals[1] = inf
	}
	testFunc(t, name, inv, vals)

	if bounded {
		// Offset used to probe just outside and just inside the bounds.
		const eps = 0.01
		lo, hi := inv(0), inv(1)
		cdfVals := map[float64]float64{
			lo - eps: 0, lo: 0,
			hi: 1, hi + eps: 1,
		}
		testFunc(t, cdfName, dist.CDF, cdfVals)
		// Report the observed value: these conditions also fail for NaN or
		// out-of-range results, not only for exactly 0 or 1.
		if got := dist.CDF(lo + eps); !(got > 0) {
			t.Errorf("%s(0)=%v, but %s(%v)=%v, want > 0", name, lo, cdfName, lo+eps, got)
		}
		if got := dist.CDF(hi - eps); !(got < 1) {
			t.Errorf("%s(1)=%v, but %s(%v)=%v, want < 1", name, hi, cdfName, hi-eps, got)
		}
	}

	// Test points between: inv(CDF(x)) must round-trip to x.
	vals = map[float64]float64{}
	for _, p := range vec.Linspace(0, 1, 11) {
		if p == 0 || p == 1 {
			continue
		}
		x := inv(p)
		vals[x] = x
	}
	testFunc(t, fmt.Sprintf("InvCDF(CDF(%+v))", dist),
		func(x float64) float64 {
			return inv(dist.CDF(x))
		},
		vals)
}