Files
pulumi-docker-build/enumerate/main.go
2024-08-13 16:37:03 -07:00

122 lines
2.6 KiB
Go

// Copyright 2024, Pulumi Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"flag"
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
"math/rand"
"os"
"path/filepath"
"regexp"
"slices"
"strings"
)
// Command-line flags controlling which tests are enumerated and how they are
// split across shards.
var (
	// rootFlag is the directory that is walked in search of *_test.go files.
	rootFlag = flag.String("root", ".", "directory to search")

	// shardFlag is the zero-based index of the shard whose tests are printed.
	shardFlag = flag.Int("shard", 0, "shard index to collect tests for")

	// nFlag is the number of shards the discovered tests are divided into.
	nFlag = flag.Int("n", 1, "the total number of shards")

	// seedFlag seeds the shuffle applied before sharding; the default of 0
	// leaves the tests in discovery order.
	seedFlag = flag.Int64("seed", 0, "randomly shuffle tests using a non-zero seed")
)
// re matches function names that `go test` treats as tests: the prefix "Test"
// either on its own or followed by any character that is not a lowercase
// letter (TestFoo, Test_foo, Test1 and Test all match; Testing does not).
var re = regexp.MustCompile(`^Test(?:$|[^\p{Ll}])`)
// testf identifies a single discovered test function.
type testf struct {
	// path is the directory containing the file that declares the test;
	// name is the identifier of the test function itself.
	path, name string
}
func main() {
flag.Parse()
root := *rootFlag
shard := *shardFlag
n := *nFlag
if shard >= n {
log.Fatal("shard must be less than n")
}
if shard < 0 || n < 0 {
log.Fatal("must be non-negative")
}
tests := []testf{}
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() || !strings.HasSuffix(path, "_test.go") {
return nil
}
// Parse the file to find test functions
fileSet := token.NewFileSet()
node, err := parser.ParseFile(fileSet, path, nil, 0)
if err != nil {
return err
}
for _, decl := range node.Decls {
f, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
name := f.Name.Name
if !re.MatchString(name) {
continue
}
tests = append(tests, testf{path: filepath.Dir(path), name: name})
}
return nil
})
if err != nil {
log.Fatal(err)
}
// Shuffle the tests.
seed := *seedFlag
if seed != 0 {
random := rand.New(rand.NewSource(seed))
for i := range tests {
j := random.Intn(i + 1) //nolint:gosec // Not cryptographic.
tests[i], tests[j] = tests[j], tests[i]
}
}
// Assign tests to our shard.
paths := []string{}
names := []string{}
for idx, test := range tests {
if idx%n != shard {
continue
}
paths = append(paths, "./"+test.path)
names = append(names, test.name)
}
// De-dupe.
slices.Sort(paths)
slices.Sort(names)
paths = slices.Compact(paths)
names = slices.Compact(names)
fmt.Printf("-run ^(%s)$ %s\n", strings.Join(names, "|"), strings.Join(paths, " "))
}