...

Source file src/cmd/compile/internal/inline/inlheur/eclassify.go

Documentation: cmd/compile/internal/inline/inlheur

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inlheur
     6  
     7  import (
     8  	"cmd/compile/internal/ir"
     9  	"fmt"
    10  	"os"
    11  )
    12  
    13  // ShouldFoldIfNameConstant analyzes expression tree 'e' to see
    14  // whether it contains only combinations of simple references to all
    15  // of the names in 'names' with selected constants + operators. The
    16  // intent is to identify expression that could be folded away to a
    17  // constant if the value of 'n' were available. Return value is TRUE
    18  // if 'e' does look foldable given the value of 'n', and given that
    19  // 'e' actually makes reference to 'n'. Some examples where the type
    20  // of "n" is int64, type of "s" is string, and type of "p" is *byte:
    21  //
    22  //	Simple?		Expr
    23  //	yes			n<10
    24  //	yes			n*n-100
    25  //	yes			(n < 10 || n > 100) && (n >= 12 || n <= 99 || n != 101)
    26  //	yes			s == "foo"
    27  //	yes			p == nil
    28  //	no			n<foo()
    29  //	no			n<1 || n>m
    30  //	no			float32(n)<1.0
    31  //	no			*p == 1
    32  //	no			1 + 100
    33  //	no			1 / n
    34  //	no			1 + unsafe.Sizeof(n)
    35  //
    36  // To avoid complexities (e.g. nan, inf) we stay way from folding and
    37  // floating point or complex operations (integers, bools, and strings
    38  // only). We also try to be conservative about avoiding any operation
    39  // that might result in a panic at runtime, e.g. for "n" with type
    40  // int64:
    41  //
    42  //	1<<(n-9) < 100/(n<<9999)
    43  //
    44  // we would return FALSE due to the negative shift count and/or
    45  // potential divide by zero.
    46  func ShouldFoldIfNameConstant(n ir.Node, names []*ir.Name) bool {
    47  	cl := makeExprClassifier(names)
    48  	var doNode func(ir.Node) bool
    49  	doNode = func(n ir.Node) bool {
    50  		ir.DoChildren(n, doNode)
    51  		cl.Visit(n)
    52  		return false
    53  	}
    54  	doNode(n)
    55  	if cl.getdisp(n) != exprSimple {
    56  		return false
    57  	}
    58  	for _, v := range cl.names {
    59  		if !v {
    60  			return false
    61  		}
    62  	}
    63  	return true
    64  }
    65  
// exprClassifier holds intermediate state about nodes within an
// expression tree being analyzed by ShouldFoldIfNameConstant.
//
// The "names" map has one entry for each name node handed to
// makeExprClassifier; an entry starts out false and is set to true
// by Visit once a reference to that name has been seen. The
// "disposition" map stores the result of classifying a given IR
// node, as recorded by Visit and read back by getdisp.
type exprClassifier struct {
	names       map[*ir.Name]bool
	disposition map[ir.Node]disp
}
    74  
    75  type disp int
    76  
    77  const (
    78  	// no info on this expr
    79  	exprNoInfo disp = iota
    80  
    81  	// expr contains only literals
    82  	exprLiterals
    83  
    84  	// expr is legal combination of literals and specified names
    85  	exprSimple
    86  )
    87  
    88  func (d disp) String() string {
    89  	switch d {
    90  	case exprNoInfo:
    91  		return "noinfo"
    92  	case exprSimple:
    93  		return "simple"
    94  	case exprLiterals:
    95  		return "literals"
    96  	default:
    97  		return fmt.Sprintf("unknown<%d>", d)
    98  	}
    99  }
   100  
   101  func makeExprClassifier(names []*ir.Name) *exprClassifier {
   102  	m := make(map[*ir.Name]bool, len(names))
   103  	for _, n := range names {
   104  		m[n] = false
   105  	}
   106  	return &exprClassifier{
   107  		names:       m,
   108  		disposition: make(map[ir.Node]disp),
   109  	}
   110  }
   111  
   112  // Visit sets the classification for 'n' based on the previously
   113  // calculated classifications for n's children, as part of a bottom-up
   114  // walk over an expression tree.
   115  func (ec *exprClassifier) Visit(n ir.Node) {
   116  
   117  	ndisp := exprNoInfo
   118  
   119  	binparts := func(n ir.Node) (ir.Node, ir.Node) {
   120  		if lex, ok := n.(*ir.LogicalExpr); ok {
   121  			return lex.X, lex.Y
   122  		} else if bex, ok := n.(*ir.BinaryExpr); ok {
   123  			return bex.X, bex.Y
   124  		} else {
   125  			panic("bad")
   126  		}
   127  	}
   128  
   129  	t := n.Type()
   130  	if t == nil {
   131  		if debugTrace&debugTraceExprClassify != 0 {
   132  			fmt.Fprintf(os.Stderr, "=-= *** untyped op=%s\n",
   133  				n.Op().String())
   134  		}
   135  	} else if t.IsInteger() || t.IsString() || t.IsBoolean() || t.HasNil() {
   136  		switch n.Op() {
   137  		// FIXME: maybe add support for OADDSTR?
   138  		case ir.ONIL:
   139  			ndisp = exprLiterals
   140  
   141  		case ir.OLITERAL:
   142  			if _, ok := n.(*ir.BasicLit); ok {
   143  			} else {
   144  				panic("unexpected")
   145  			}
   146  			ndisp = exprLiterals
   147  
   148  		case ir.ONAME:
   149  			nn := n.(*ir.Name)
   150  			if _, ok := ec.names[nn]; ok {
   151  				ndisp = exprSimple
   152  				ec.names[nn] = true
   153  			} else {
   154  				sv := ir.StaticValue(n)
   155  				if sv.Op() == ir.ONAME {
   156  					nn = sv.(*ir.Name)
   157  				}
   158  				if _, ok := ec.names[nn]; ok {
   159  					ndisp = exprSimple
   160  					ec.names[nn] = true
   161  				}
   162  			}
   163  
   164  		case ir.ONOT,
   165  			ir.OPLUS,
   166  			ir.ONEG:
   167  			uex := n.(*ir.UnaryExpr)
   168  			ndisp = ec.getdisp(uex.X)
   169  
   170  		case ir.OEQ,
   171  			ir.ONE,
   172  			ir.OLT,
   173  			ir.OGT,
   174  			ir.OGE,
   175  			ir.OLE:
   176  			// compare ops
   177  			x, y := binparts(n)
   178  			ndisp = ec.dispmeet(x, y)
   179  			if debugTrace&debugTraceExprClassify != 0 {
   180  				fmt.Fprintf(os.Stderr, "=-= meet(%s,%s) = %s for op=%s\n",
   181  					ec.getdisp(x), ec.getdisp(y), ec.dispmeet(x, y),
   182  					n.Op().String())
   183  			}
   184  		case ir.OLSH,
   185  			ir.ORSH,
   186  			ir.ODIV,
   187  			ir.OMOD:
   188  			x, y := binparts(n)
   189  			if ec.getdisp(y) == exprLiterals {
   190  				ndisp = ec.dispmeet(x, y)
   191  			}
   192  
   193  		case ir.OADD,
   194  			ir.OSUB,
   195  			ir.OOR,
   196  			ir.OXOR,
   197  			ir.OMUL,
   198  			ir.OAND,
   199  			ir.OANDNOT,
   200  			ir.OANDAND,
   201  			ir.OOROR:
   202  			x, y := binparts(n)
   203  			if debugTrace&debugTraceExprClassify != 0 {
   204  				fmt.Fprintf(os.Stderr, "=-= meet(%s,%s) = %s for op=%s\n",
   205  					ec.getdisp(x), ec.getdisp(y), ec.dispmeet(x, y),
   206  					n.Op().String())
   207  			}
   208  			ndisp = ec.dispmeet(x, y)
   209  		}
   210  	}
   211  
   212  	if debugTrace&debugTraceExprClassify != 0 {
   213  		fmt.Fprintf(os.Stderr, "=-= op=%s disp=%v\n", n.Op().String(),
   214  			ndisp.String())
   215  	}
   216  
   217  	ec.disposition[n] = ndisp
   218  }
   219  
   220  func (ec *exprClassifier) getdisp(x ir.Node) disp {
   221  	if d, ok := ec.disposition[x]; ok {
   222  		return d
   223  	} else {
   224  		panic("missing node from disp table")
   225  	}
   226  }
   227  
   228  // dispmeet performs a "meet" operation on the data flow states of
   229  // node x and y (where the term "meet" is being drawn from traditional
   230  // lattice-theoretical data flow analysis terminology).
   231  func (ec *exprClassifier) dispmeet(x, y ir.Node) disp {
   232  	xd := ec.getdisp(x)
   233  	if xd == exprNoInfo {
   234  		return exprNoInfo
   235  	}
   236  	yd := ec.getdisp(y)
   237  	if yd == exprNoInfo {
   238  		return exprNoInfo
   239  	}
   240  	if xd == exprSimple || yd == exprSimple {
   241  		return exprSimple
   242  	}
   243  	if xd != exprLiterals || yd != exprLiterals {
   244  		panic("unexpected")
   245  	}
   246  	return exprLiterals
   247  }
   248  

View as plain text