76 lines
1.7 KiB
Go
76 lines
1.7 KiB
Go
// Copyright 2015 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package stats
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
"github.com/aclements/go-moremath/internal/mathtest"
|
|
"github.com/aclements/go-moremath/vec"
|
|
)
|
|
|
|
var aeq = mathtest.Aeq
|
|
var testFunc = mathtest.WantFunc
|
|
|
|
func testDiscreteCDF(t *testing.T, name string, dist DiscreteDist) {
|
|
// Build the expected CDF out of the PMF.
|
|
l, h := dist.Bounds()
|
|
s := dist.Step()
|
|
want := map[float64]float64{l - 0.1: 0, h: 1}
|
|
sum := 0.0
|
|
for x := l; x < h; x += s {
|
|
sum += dist.PMF(x)
|
|
want[x] = sum
|
|
want[x+s/2] = sum
|
|
}
|
|
|
|
testFunc(t, name, dist.CDF, want)
|
|
}
|
|
|
|
func testInvCDF(t *testing.T, dist Dist, bounded bool) {
|
|
inv := InvCDF(dist)
|
|
name := fmt.Sprintf("InvCDF(%+v)", dist)
|
|
cdfName := fmt.Sprintf("CDF(%+v)", dist)
|
|
|
|
// Test bounds.
|
|
vals := map[float64]float64{-0.01: nan, 1.01: nan}
|
|
if !bounded {
|
|
vals[0] = -inf
|
|
vals[1] = inf
|
|
}
|
|
testFunc(t, name, inv, vals)
|
|
|
|
if bounded {
|
|
lo, hi := inv(0), inv(1)
|
|
vals := map[float64]float64{
|
|
lo - 0.01: 0, lo: 0,
|
|
hi: 1, hi + 0.01: 1,
|
|
}
|
|
testFunc(t, cdfName, dist.CDF, vals)
|
|
if got := dist.CDF(lo + 0.01); !(got > 0) {
|
|
t.Errorf("%s(0)=%v, but %s(%v)=0", name, lo, cdfName, lo+0.01)
|
|
}
|
|
if got := dist.CDF(hi - 0.01); !(got < 1) {
|
|
t.Errorf("%s(1)=%v, but %s(%v)=1", name, hi, cdfName, hi-0.01)
|
|
}
|
|
}
|
|
|
|
// Test points between.
|
|
vals = map[float64]float64{}
|
|
for _, p := range vec.Linspace(0, 1, 11) {
|
|
if p == 0 || p == 1 {
|
|
continue
|
|
}
|
|
x := inv(p)
|
|
vals[x] = x
|
|
}
|
|
testFunc(t, fmt.Sprintf("InvCDF(CDF(%+v))", dist),
|
|
func(x float64) float64 {
|
|
return inv(dist.CDF(x))
|
|
},
|
|
vals)
|
|
}
|