190 lines
3.9 KiB
Go
190 lines
3.9 KiB
Go
|
package parse
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"go/ast"
|
||
|
"go/doc"
|
||
|
"go/parser"
|
||
|
"go/token"
|
||
|
"go/types"
|
||
|
"log"
|
||
|
"os"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
type PkgInfo struct {
|
||
|
Funcs []Function
|
||
|
DefaultIsError bool
|
||
|
DefaultName string
|
||
|
}
|
||
|
|
||
|
// Function describes a single function found while parsing a package.
type Function struct {
	// Name is the function's identifier.
	Name string
	// IsError reports whether the function returns an error rather than
	// returning nothing.
	IsError bool
	// Synopsis is the first sentence of the function's doc comment.
	Synopsis string
	// Comment is the function's complete doc comment.
	Comment string
}
|
||
|
|
||
|
// Package parses a package
|
||
|
func Package(path string, files []string) (*PkgInfo, error) {
|
||
|
fset := token.NewFileSet()
|
||
|
|
||
|
pkg, err := getPackage(path, files, fset)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
info, err := makeInfo(path, fset, pkg.Files)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
pi := &PkgInfo{}
|
||
|
|
||
|
p := doc.New(pkg, "./", 0)
|
||
|
for _, f := range p.Funcs {
|
||
|
if f.Recv != "" {
|
||
|
// skip methods
|
||
|
continue
|
||
|
}
|
||
|
if !ast.IsExported(f.Name) {
|
||
|
// skip non-exported functions
|
||
|
continue
|
||
|
}
|
||
|
if typ := voidOrError(f.Decl.Type, info); typ != InvalidType {
|
||
|
pi.Funcs = append(pi.Funcs, Function{
|
||
|
Name: f.Name,
|
||
|
Comment: f.Doc,
|
||
|
Synopsis: doc.Synopsis(f.Doc),
|
||
|
IsError: voidOrError(f.Decl.Type, info) == ErrorType,
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
setDefault(p, pi, info)
|
||
|
|
||
|
return pi, nil
|
||
|
}
|
||
|
|
||
|
func setDefault(p *doc.Package, pi *PkgInfo, info types.Info) {
|
||
|
for _, v := range p.Vars {
|
||
|
for x, name := range v.Names {
|
||
|
if name != "Default" {
|
||
|
continue
|
||
|
}
|
||
|
spec := v.Decl.Specs[x].(*ast.ValueSpec)
|
||
|
if len(spec.Values) != 1 {
|
||
|
log.Println("warning: default declaration has multiple values")
|
||
|
}
|
||
|
id, ok := spec.Values[0].(*ast.Ident)
|
||
|
if !ok {
|
||
|
log.Println("warning: default declaration is not a function name")
|
||
|
}
|
||
|
for _, f := range pi.Funcs {
|
||
|
if f.Name == id.Name {
|
||
|
pi.DefaultName = f.Name
|
||
|
pi.DefaultIsError = f.IsError
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
log.Println("warning: default declaration does not reference a mage target")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// getPackage parses the files named in files from the directory path and
// returns the non-test package they form. Positions are recorded in fset.
//
// Only directory entries whose base name appears in files are parsed.
// Packages whose name ends in "_test" are ignored. If more than one non-test
// package is found, the one with the lexically smallest name is returned so
// that the result does not depend on map iteration order.
//
// An error is returned when parsing fails or no non-test package is found.
func getPackage(path string, files []string, fset *token.FileSet) (*ast.Package, error) {
	fm := make(map[string]bool, len(files))
	for _, f := range files {
		fm[f] = true
	}

	filter := func(f os.FileInfo) bool {
		return fm[f.Name()]
	}

	pkgs, err := parser.ParseDir(fset, path, filter, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	var (
		chosen     *ast.Package
		chosenName string
	)
	for name, pkg := range pkgs {
		if strings.HasSuffix(name, "_test") {
			continue
		}
		if chosen == nil || name < chosenName {
			chosen, chosenName = pkg, name
		}
	}
	if chosen == nil {
		return nil, fmt.Errorf("no non-test packages found in %s", path)
	}
	return chosen, nil
}
|
||
|
|
||
|
func makeInfo(dir string, fset *token.FileSet, files map[string]*ast.File) (types.Info, error) {
|
||
|
cfg := types.Config{
|
||
|
Importer: getImporter(fset),
|
||
|
}
|
||
|
|
||
|
info := types.Info{
|
||
|
Types: make(map[ast.Expr]types.TypeAndValue),
|
||
|
Defs: make(map[*ast.Ident]types.Object),
|
||
|
Uses: make(map[*ast.Ident]types.Object),
|
||
|
}
|
||
|
|
||
|
fs := make([]*ast.File, 0, len(files))
|
||
|
for _, v := range files {
|
||
|
fs = append(fs, v)
|
||
|
}
|
||
|
|
||
|
_, err := cfg.Check(dir, fset, fs, &info)
|
||
|
return info, err
|
||
|
}
|
||
|
|
||
|
// errorOrVoid filters the list of functions to only those that return only an
|
||
|
// error or have no return value, and have no paramters.
|
||
|
func errorOrVoid(fns []*ast.FuncDecl, info types.Info) []*ast.FuncDecl {
|
||
|
fds := []*ast.FuncDecl{}
|
||
|
|
||
|
for _, fn := range fns {
|
||
|
if voidOrError(fn.Type, info) != InvalidType {
|
||
|
fds = append(fds, fn)
|
||
|
}
|
||
|
}
|
||
|
return fds
|
||
|
}
|
||
|
|
||
|
// FuncType classifies a function signature by whether, and how, mage
// understands it.
type FuncType int

// Classifications returned by voidOrError.
const (
	// InvalidType is a signature that is not understood. It is the zero value.
	InvalidType FuncType = iota
	// VoidType is a function with no parameters and no results.
	VoidType
	// ErrorType is a function with no parameters and a single error result.
	ErrorType
)
|
||
|
|
||
|
func voidOrError(ft *ast.FuncType, info types.Info) FuncType {
|
||
|
// look for functions with no params
|
||
|
if ft.Params.NumFields() > 0 {
|
||
|
return InvalidType
|
||
|
}
|
||
|
|
||
|
// look for functions with 0 or 1 return values
|
||
|
res := ft.Results
|
||
|
if res.NumFields() > 1 {
|
||
|
return InvalidType
|
||
|
}
|
||
|
// 0 return value is ok
|
||
|
if res.NumFields() == 0 {
|
||
|
return VoidType
|
||
|
}
|
||
|
// if 1 return value, look for those that return an error
|
||
|
ret := res.List[0]
|
||
|
|
||
|
// handle (a, b, c int)
|
||
|
if len(ret.Names) > 1 {
|
||
|
return InvalidType
|
||
|
}
|
||
|
t := info.TypeOf(ret.Type)
|
||
|
if t != nil && t.String() == "error" {
|
||
|
return ErrorType
|
||
|
}
|
||
|
return InvalidType
|
||
|
}
|