site/vendor/github.com/magefile/mage/parse/parse.go

272 lines
6.1 KiB
Go
Raw Normal View History

2017-12-13 18:43:58 +00:00
package parse
import (
"fmt"
"go/ast"
"go/build"
"go/doc"
"go/parser"
"go/token"
"go/types"
"log"
"os"
"os/exec"
"strings"
mgTypes "github.com/magefile/mage/types"
)
// PkgInfo describes the mage targets found in a parsed package, as built by
// Package.
type PkgInfo struct {
	// Funcs holds every exported, receiver-less function whose signature is
	// accepted as a target (see voidOrError).
	Funcs []Function
	// DefaultIsError mirrors DefaultFunc.IsError.
	DefaultIsError bool
	// DefaultIsContext mirrors DefaultFunc.IsContext.
	DefaultIsContext bool
	// DefaultName mirrors DefaultFunc.Name; it is empty when no usable
	// package-level Default variable was found.
	DefaultName string
	// DefaultFunc is the target referenced by the package-level Default
	// variable, if any.
	DefaultFunc Function
}

// Function represents a job function (target) from a mage file.
type Function struct {
	// Name is the function's identifier.
	Name string
	// IsError is true when the function returns a single error.
	IsError bool
	// IsContext is true when the function takes a single context.Context.
	IsContext bool
	// Synopsis is the first sentence of the doc comment (doc.Synopsis).
	Synopsis string
	// Comment is the full doc comment text.
	Comment string
}
// TemplateString returns code for the template switch to run the target.
// It wraps each target call to match the func(context.Context) error that
// runTarget requires. Targets with IsContext are called with the wrapper's
// ctx argument. Targets with IsError have their error returned directly.
// Other targets are called and then the wrapper returns nil.
//
// IsContext and IsError are independent booleans, so the two choices below
// cover every Function and there is no failure case.
func (f Function) TemplateString() string {
	call := f.Name + "()"
	if f.IsContext {
		call = f.Name + "(ctx)"
	}

	stmt := call + "\nreturn nil"
	if f.IsError {
		stmt = "return " + call
	}

	return "wrapFn := func(ctx context.Context) error {\n" +
		stmt +
		"\n}\nerr := runTarget(wrapFn)"
}
// Package parses the files named in files (base names inside the directory
// path) and collects the mage targets they declare. A target is an exported,
// receiver-less function whose signature voidOrError accepts. The
// package-level Default variable, if any, is then resolved by setDefault.
func Package(path string, files []string) (*PkgInfo, error) {
	fset := token.NewFileSet()

	astPkg, err := getPackage(path, files, fset)
	if err != nil {
		return nil, err
	}

	// Type-check before doc.New: doc.New may edit the AST it is given.
	typeInfo, err := makeInfo(path, fset, astPkg.Files)
	if err != nil {
		return nil, err
	}

	result := &PkgInfo{}
	docPkg := doc.New(astPkg, "./", 0)
	for _, fn := range docPkg.Funcs {
		// Methods and unexported functions are never targets.
		if fn.Recv != "" || !ast.IsExported(fn.Name) {
			continue
		}

		kind := voidOrError(fn.Decl.Type, typeInfo)
		if kind == mgTypes.InvalidType {
			continue
		}

		result.Funcs = append(result.Funcs, Function{
			Name:      fn.Name,
			Comment:   fn.Doc,
			Synopsis:  doc.Synopsis(fn.Doc),
			IsError:   kind == mgTypes.ErrorType || kind == mgTypes.ContextErrorType,
			IsContext: kind == mgTypes.ContextVoidType || kind == mgTypes.ContextErrorType,
		})
	}

	setDefault(docPkg, result, typeInfo)
	return result, nil
}
// setDefault looks for a package-level variable named Default among p.Vars.
// If its initializer is the bare identifier of a target already listed in
// pi.Funcs, setDefault copies that target into pi's Default* fields. Any other
// shape is reported with log.Println and leaves pi's Default* fields
// unchanged. Other shapes include no initializer, a non-identifier value, or a
// name that is not a target.
//
// The name is located inside its own *ast.ValueSpec rather than by indexing
// Decl.Specs with the name's position in v.Names. A multi-name declaration
// such as `var Other, Default = 1, Build` has fewer specs than names, and its
// value must be taken at the same index as the name. info is currently unused.
func setDefault(p *doc.Package, pi *PkgInfo, info types.Info) {
	for _, v := range p.Vars {
		for _, s := range v.Decl.Specs {
			spec, ok := s.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range spec.Names {
				if name.Name != "Default" {
					continue
				}

				switch {
				case len(spec.Values) == 0:
					// e.g. `var Default func()`: nothing to resolve.
					log.Println("warning: default declaration has no value")
					return
				case len(spec.Values) != len(spec.Names):
					// e.g. `var a, Default = f()`: one multi-value expression,
					// so no single value belongs to Default.
					log.Println("warning: default declaration has multiple values")
					return
				}

				id, ok := spec.Values[i].(*ast.Ident)
				if !ok {
					log.Println("warning: default declaration is not a function name")
					return
				}

				for _, f := range pi.Funcs {
					if f.Name == id.Name {
						pi.DefaultName = f.Name
						pi.DefaultIsError = f.IsError
						pi.DefaultIsContext = f.IsContext
						pi.DefaultFunc = f
						return
					}
				}
				log.Println("warning: default declaration does not reference a mage target")
				return
			}
		}
	}
}
// getPackage parses, with comments, only those files in directory path whose
// base names appear in files. It returns a parsed package whose name does not
// end in "_test". If the selected files declare several such packages, which
// one is returned is unspecified because it follows map iteration order.
func getPackage(path string, files []string, fset *token.FileSet) (*ast.Package, error) {
	wanted := make(map[string]bool, len(files))
	for _, name := range files {
		wanted[name] = true
	}
	onlyWanted := func(fi os.FileInfo) bool { return wanted[fi.Name()] }

	parsed, err := parser.ParseDir(fset, path, onlyWanted, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("failed to parse directory: %v", err)
	}

	for pkgName, astPkg := range parsed {
		if strings.HasSuffix(pkgName, "_test") {
			continue
		}
		return astPkg, nil
	}
	return nil, fmt.Errorf("no non-test packages found in %s", path)
}
// makeInfo type-checks files as the package dir. It returns the recorded
// expression types, definitions and uses.
//
// GOROOT is taken from the environment or, if unset, from `go env GOROOT`.
// Either way the value is stored in build.Default.GOROOT before the importer
// from getImporter is created. On a type-checking error the partially
// populated info is returned alongside the error. Files are handed to the
// checker in map iteration order, which is unspecified.
func makeInfo(dir string, fset *token.FileSet, files map[string]*ast.File) (types.Info, error) {
	goroot := os.Getenv("GOROOT")
	if goroot == "" {
		out, err := exec.Command("go", "env", "GOROOT").Output()
		if err != nil {
			return types.Info{}, fmt.Errorf("failed to get GOROOT from 'go env': %v", err)
		}
		if goroot = strings.TrimSpace(string(out)); goroot == "" {
			return types.Info{}, fmt.Errorf("could not determine GOROOT")
		}
	}
	// NOTE(review): presumably getImporter consults build.Default, so keep this
	// assignment ahead of it.
	build.Default.GOROOT = goroot

	conf := types.Config{Importer: getImporter(fset)}
	info := types.Info{
		Types: map[ast.Expr]types.TypeAndValue{},
		Defs:  map[*ast.Ident]types.Object{},
		Uses:  map[*ast.Ident]types.Object{},
	}

	astFiles := make([]*ast.File, 0, len(files))
	for _, file := range files {
		astFiles = append(astFiles, file)
	}

	if _, err := conf.Check(dir, fset, astFiles, &info); err != nil {
		return info, fmt.Errorf("failed to check types in directory: %v", err)
	}
	return info, nil
}
// errorOrVoid returns, in their original order, the functions in fns whose
// signature voidOrError accepts. Accepted signatures take either no
// parameters or a single context.Context, and return either nothing or a
// single error. The returned slice is non-nil even when nothing matches.
func errorOrVoid(fns []*ast.FuncDecl, info types.Info) []*ast.FuncDecl {
	kept := make([]*ast.FuncDecl, 0, len(fns))
	for i := range fns {
		if voidOrError(fns[i].Type, info) == mgTypes.InvalidType {
			continue
		}
		kept = append(kept, fns[i])
	}
	return kept
}
// hasContextParam reports whether ft declares exactly one parameter whose
// type, as recorded in info, prints as "context.Context".
func hasContextParam(ft *ast.FuncType, info types.Info) bool {
	if ft.Params.NumFields() != 1 {
		return false
	}
	paramType := info.TypeOf(ft.Params.List[0].Type)
	return paramType != nil && paramType.String() == "context.Context"
}
// hasVoidReturn reports whether ft declares no results at all. The info
// argument is not used.
func hasVoidReturn(ft *ast.FuncType, info types.Info) bool {
	return ft.Results.NumFields() == 0
}
// hasErrorReturn reports whether ft declares exactly one result, named or
// not, whose type, as recorded in info, prints as "error".
func hasErrorReturn(ft *ast.FuncType, info types.Info) bool {
	// NumFields counts each name separately, so a count of one means a single
	// field with at most one name.
	if ft.Results.NumFields() != 1 {
		return false
	}
	resultType := info.TypeOf(ft.Results.List[0].Type)
	return resultType != nil && resultType.String() == "error"
}
// voidOrError classifies ft as a mage target signature:
//
//	func(context.Context)        -> ContextVoidType
//	func(context.Context) error  -> ContextErrorType
//	func()                       -> VoidType
//	func() error                 -> ErrorType
//
// Any other signature yields InvalidType.
func voidOrError(ft *ast.FuncType, info types.Info) mgTypes.FuncType {
	switch {
	case hasContextParam(ft, info):
		switch {
		case hasVoidReturn(ft, info):
			return mgTypes.ContextVoidType
		case hasErrorReturn(ft, info):
			return mgTypes.ContextErrorType
		}
	case ft.Params.NumFields() == 0:
		switch {
		case hasVoidReturn(ft, info):
			return mgTypes.VoidType
		case hasErrorReturn(ft, info):
			return mgTypes.ErrorType
		}
	}
	return mgTypes.InvalidType
}