// Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This program generates a test to verify that the standard comparison // operators properly handle one const operand. The test file should be // generated with a known working version of go. // launch with `go run cmpConstGen.go` a file called cmpConst.go // will be written into the parent directory containing the tests package main import ( "bytes" "fmt" "go/format" "log" "math/big" "sort" ) const ( maxU64 = (1 << 64) - 1 maxU32 = (1 << 32) - 1 maxU16 = (1 << 16) - 1 maxU8 = (1 << 8) - 1 maxI64 = (1 << 63) - 1 maxI32 = (1 << 31) - 1 maxI16 = (1 << 15) - 1 maxI8 = (1 << 7) - 1 minI64 = -(1 << 63) minI32 = -(1 << 31) minI16 = -(1 << 15) minI8 = -(1 << 7) ) func cmp(left *big.Int, op string, right *big.Int) bool { switch left.Cmp(right) { case -1: // less than return op == "<" || op == "<=" || op == "!=" case 0: // equal return op == "==" || op == "<=" || op == ">=" case 1: // greater than return op == ">" || op == ">=" || op == "!=" } panic("unexpected comparison value") } func inRange(typ string, val *big.Int) bool { min, max := &big.Int{}, &big.Int{} switch typ { case "uint64": max = max.SetUint64(maxU64) case "uint32": max = max.SetUint64(maxU32) case "uint16": max = max.SetUint64(maxU16) case "uint8": max = max.SetUint64(maxU8) case "int64": min = min.SetInt64(minI64) max = max.SetInt64(maxI64) case "int32": min = min.SetInt64(minI32) max = max.SetInt64(maxI32) case "int16": min = min.SetInt64(minI16) max = max.SetInt64(maxI16) case "int8": min = min.SetInt64(minI8) max = max.SetInt64(maxI8) default: panic("unexpected type") } return cmp(min, "<=", val) && cmp(val, "<=", max) } func getValues(typ string) []*big.Int { Uint := func(v uint64) *big.Int { return big.NewInt(0).SetUint64(v) } Int := func(v int64) *big.Int { return big.NewInt(0).SetInt64(v) } values := []*big.Int{ // limits Uint(maxU64), Uint(maxU64 - 1), Uint(maxI64 + 1), Uint(maxI64), Uint(maxI64 - 1), Uint(maxU32 + 1), Uint(maxU32), Uint(maxU32 - 1), Uint(maxI32 + 1), Uint(maxI32), Uint(maxI32 - 1), Uint(maxU16 + 1), Uint(maxU16), Uint(maxU16 - 1), Uint(maxI16 + 1), Uint(maxI16), Uint(maxI16 - 1), Uint(maxU8 + 1), Uint(maxU8), Uint(maxU8 - 1), Uint(maxI8 + 1), Uint(maxI8), Uint(maxI8 - 1), Uint(0), Int(minI8 + 1), Int(minI8), Int(minI8 - 1), Int(minI16 + 1), Int(minI16), Int(minI16 - 1), Int(minI32 + 1), Int(minI32), Int(minI32 - 1), Int(minI64 + 1), Int(minI64), // other possibly interesting values Uint(1), Int(-1), Uint(0xff << 56), Uint(0xff << 32), Uint(0xff << 24), } sort.Slice(values, func(i, j int) bool { return values[i].Cmp(values[j]) == -1 }) var ret []*big.Int for _, val := range values { if !inRange(typ, val) { continue } ret = append(ret, val) } return ret } func sigString(v *big.Int) string { var t big.Int t.Abs(v) if v.Sign() == -1 { return "neg" + t.String() } return t.String() } func main() { types := []string{ "uint64", "uint32", "uint16", "uint8", "int64", "int32", "int16", "int8", } w := new(bytes.Buffer) fmt.Fprintf(w, "// Code generated by gen/cmpConstGen.go. DO NOT EDIT.\n\n") fmt.Fprintf(w, "package main;\n") fmt.Fprintf(w, "import (\"testing\"; \"reflect\"; \"runtime\";)\n") fmt.Fprintf(w, "// results show the expected result for the elements left of, equal to and right of the index.\n") fmt.Fprintf(w, "type result struct{l, e, r bool}\n") fmt.Fprintf(w, "var (\n") fmt.Fprintf(w, " eq = result{l: false, e: true, r: false}\n") fmt.Fprintf(w, " ne = result{l: true, e: false, r: true}\n") fmt.Fprintf(w, " lt = result{l: true, e: false, r: false}\n") fmt.Fprintf(w, " le = result{l: true, e: true, r: false}\n") fmt.Fprintf(w, " gt = result{l: false, e: false, r: true}\n") fmt.Fprintf(w, " ge = result{l: false, e: true, r: true}\n") fmt.Fprintf(w, ")\n") operators := []struct{ op, name string }{ {"<", "lt"}, {"<=", "le"}, {">", "gt"}, {">=", "ge"}, {"==", "eq"}, {"!=", "ne"}, } for _, typ := range types { // generate a slice containing valid values for this type fmt.Fprintf(w, "\n// %v tests\n", typ) values := getValues(typ) fmt.Fprintf(w, "var %v_vals = []%v{\n", typ, typ) for _, val := range values { fmt.Fprintf(w, "%v,\n", val.String()) } fmt.Fprintf(w, "}\n") // generate test functions for _, r := range values { // TODO: could also test constant on lhs. sig := sigString(r) for _, op := range operators { // no need for go:noinline because the function is called indirectly fmt.Fprintf(w, "func %v_%v_%v(x %v) bool { return x %v %v; }\n", op.name, sig, typ, typ, op.op, r.String()) } } // generate a table of test cases fmt.Fprintf(w, "var %v_tests = []struct{\n", typ) fmt.Fprintf(w, " idx int // index of the constant used\n") fmt.Fprintf(w, " exp result // expected results\n") fmt.Fprintf(w, " fn func(%v) bool\n", typ) fmt.Fprintf(w, "}{\n") for i, r := range values { sig := sigString(r) for _, op := range operators { fmt.Fprintf(w, "{idx: %v,", i) fmt.Fprintf(w, "exp: %v,", op.name) fmt.Fprintf(w, "fn: %v_%v_%v},\n", op.name, sig, typ) } } fmt.Fprintf(w, "}\n") } // emit the main function, looping over all test cases fmt.Fprintf(w, "// TestComparisonsConst tests results for comparison operations against constants.\n") fmt.Fprintf(w, "func TestComparisonsConst(t *testing.T) {\n") for _, typ := range types { fmt.Fprintf(w, "for i, test := range %v_tests {\n", typ) fmt.Fprintf(w, " for j, x := range %v_vals {\n", typ) fmt.Fprintf(w, " want := test.exp.l\n") fmt.Fprintf(w, " if j == test.idx {\nwant = test.exp.e\n}") fmt.Fprintf(w, " else if j > test.idx {\nwant = test.exp.r\n}\n") fmt.Fprintf(w, " if test.fn(x) != want {\n") fmt.Fprintf(w, " fn := runtime.FuncForPC(reflect.ValueOf(test.fn).Pointer()).Name()\n") fmt.Fprintf(w, " t.Errorf(\"test failed: %%v(%%v) != %%v [type=%v i=%%v j=%%v idx=%%v]\", fn, x, want, i, j, test.idx)\n", typ) fmt.Fprintf(w, " }\n") fmt.Fprintf(w, " }\n") fmt.Fprintf(w, "}\n") } fmt.Fprintf(w, "}\n") // gofmt result b := w.Bytes() src, err := format.Source(b) if err != nil { fmt.Printf("%s\n", b) panic(err) } // write to file err = os.WriteFile("../cmpConst_test.go", src, 0666) if err != nil { log.Fatalf("can't write output: %v\n", err) } }