...

Source file src/cmd/compile/internal/inline/inlheur/score_callresult_uses.go

Documentation: cmd/compile/internal/inline/inlheur

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inlheur
     6  
     7  import (
     8  	"cmd/compile/internal/ir"
     9  	"fmt"
    10  	"os"
    11  )
    12  
    13  // This file contains code to re-score callsites based on how the
    14  // results of the call were used.  Example:
    15  //
    16  //    func foo() {
    17  //       x, fptr := bar()
    18  //       switch x {
    19  //         case 10: fptr = baz()
    20  //         default: blix()
    21  //       }
    22  //       fptr(100)
    23  //     }
    24  //
// The initial scoring pass will assign a score to "bar()" based on
// various criteria; however, once the first pass of scoring is done,
    27  // we look at the flags on the result from bar, and check to see
    28  // how those results are used. If bar() always returns the same constant
    29  // for its first result, and if the variable receiving that result
    30  // isn't redefined, and if that variable feeds into an if/switch
    31  // condition, then we will try to adjust the score for "bar" (on the
    32  // theory that if we inlined, we can constant fold / deadcode).
    33  
// resultPropAndCS pairs the call site that defines a local with the
// result property bits recorded for the particular result assigned to
// that local. It is the value type of resultUseAnalyzer.resultNameTab.
type resultPropAndCS struct {
	defcs *CallSite      // call site whose result defines the name
	props ResultPropBits // property bits for that specific result
}
    38  
// resultUseAnalyzer carries the state for a walk over a function body
// that looks at how call results are used and adjusts the scores of
// the call sites that produced them.
type resultUseAnalyzer struct {
	// resultNameTab maps a local (user-level variable or auto-temp)
	// that receives an interesting call result to the defining call
	// site plus the property bits of that result.
	resultNameTab map[*ir.Name]resultPropAndCS
	fn            *ir.Func    // function whose body is being walked
	cstab         CallSiteTab // call site table, indexed by call expression
	// Receives pre/post notifications for every visited node.
	// NOTE(review): presumably tracks nesting of conditional
	// constructs; confirm against condLevelTracker's definition.
	*condLevelTracker
}
    45  
    46  // rescoreBasedOnCallResultUses examines how call results are used,
    47  // and tries to update the scores of calls based on how their results
    48  // are used in the function.
    49  func (csa *callSiteAnalyzer) rescoreBasedOnCallResultUses(fn *ir.Func, resultNameTab map[*ir.Name]resultPropAndCS, cstab CallSiteTab) {
    50  	enableDebugTraceIfEnv()
    51  	rua := &resultUseAnalyzer{
    52  		resultNameTab:    resultNameTab,
    53  		fn:               fn,
    54  		cstab:            cstab,
    55  		condLevelTracker: new(condLevelTracker),
    56  	}
    57  	var doNode func(ir.Node) bool
    58  	doNode = func(n ir.Node) bool {
    59  		rua.nodeVisitPre(n)
    60  		ir.DoChildren(n, doNode)
    61  		rua.nodeVisitPost(n)
    62  		return false
    63  	}
    64  	doNode(fn)
    65  	disableDebugTrace()
    66  }
    67  
    68  func (csa *callSiteAnalyzer) examineCallResults(cs *CallSite, resultNameTab map[*ir.Name]resultPropAndCS) map[*ir.Name]resultPropAndCS {
    69  	if debugTrace&debugTraceScoring != 0 {
    70  		fmt.Fprintf(os.Stderr, "=-= examining call results for %q\n",
    71  			EncodeCallSiteKey(cs))
    72  	}
    73  
    74  	// Invoke a helper to pick out the specific ir.Name's the results
    75  	// from this call are assigned into, e.g. "x, y := fooBar()". If
    76  	// the call is not part of an assignment statement, or if the
    77  	// variables in question are not newly defined, then we'll receive
    78  	// an empty list here.
    79  	//
    80  	names, autoTemps, props := namesDefined(cs)
    81  	if len(names) == 0 {
    82  		return resultNameTab
    83  	}
    84  
    85  	if debugTrace&debugTraceScoring != 0 {
    86  		fmt.Fprintf(os.Stderr, "=-= %d names defined\n", len(names))
    87  	}
    88  
    89  	// For each returned value, if the value has interesting
    90  	// properties (ex: always returns the same constant), and the name
    91  	// in question is never redefined, then make an entry in the
    92  	// result table for it.
    93  	const interesting = (ResultIsConcreteTypeConvertedToInterface |
    94  		ResultAlwaysSameConstant | ResultAlwaysSameInlinableFunc | ResultAlwaysSameFunc)
    95  	for idx, n := range names {
    96  		rprop := props.ResultFlags[idx]
    97  
    98  		if debugTrace&debugTraceScoring != 0 {
    99  			fmt.Fprintf(os.Stderr, "=-= props for ret %d %q: %s\n",
   100  				idx, n.Sym().Name, rprop.String())
   101  		}
   102  
   103  		if rprop&interesting == 0 {
   104  			continue
   105  		}
   106  		if csa.nameFinder.reassigned(n) {
   107  			continue
   108  		}
   109  		if resultNameTab == nil {
   110  			resultNameTab = make(map[*ir.Name]resultPropAndCS)
   111  		} else if _, ok := resultNameTab[n]; ok {
   112  			panic("should never happen")
   113  		}
   114  		entry := resultPropAndCS{
   115  			defcs: cs,
   116  			props: rprop,
   117  		}
   118  		resultNameTab[n] = entry
   119  		if autoTemps[idx] != nil {
   120  			resultNameTab[autoTemps[idx]] = entry
   121  		}
   122  		if debugTrace&debugTraceScoring != 0 {
   123  			fmt.Fprintf(os.Stderr, "=-= add resultNameTab table entry n=%v autotemp=%v props=%s\n", n, autoTemps[idx], rprop.String())
   124  		}
   125  	}
   126  	return resultNameTab
   127  }
   128  
   129  // namesDefined returns a list of ir.Name's corresponding to locals
   130  // that receive the results from the call at site 'cs', plus the
   131  // properties object for the called function. If a given result
   132  // isn't cleanly assigned to a newly defined local, the
   133  // slot for that result in the returned list will be nil. Example:
   134  //
   135  //	call                             returned name list
   136  //
   137  //	x := foo()                       [ x ]
   138  //	z, y := bar()                    [ nil, nil ]
   139  //	_, q := baz()                    [ nil, q ]
   140  //
   141  // In the case of a multi-return call, such as "x, y := foo()",
   142  // the pattern we see from the front end will be a call op
   143  // assigning to auto-temps, and then an assignment of the auto-temps
   144  // to the user-level variables. In such cases we return
   145  // first the user-level variable (in the first func result)
   146  // and then the auto-temp name in the second result.
   147  func namesDefined(cs *CallSite) ([]*ir.Name, []*ir.Name, *FuncProps) {
   148  	// If this call doesn't feed into an assignment (and of course not
   149  	// all calls do), then we don't have anything to work with here.
   150  	if cs.Assign == nil {
   151  		return nil, nil, nil
   152  	}
   153  	funcInlHeur, ok := fpmap[cs.Callee]
   154  	if !ok {
   155  		// TODO: add an assert/panic here.
   156  		return nil, nil, nil
   157  	}
   158  	if len(funcInlHeur.props.ResultFlags) == 0 {
   159  		return nil, nil, nil
   160  	}
   161  
   162  	// Single return case.
   163  	if len(funcInlHeur.props.ResultFlags) == 1 {
   164  		asgn, ok := cs.Assign.(*ir.AssignStmt)
   165  		if !ok {
   166  			return nil, nil, nil
   167  		}
   168  		// locate name being assigned
   169  		aname, ok := asgn.X.(*ir.Name)
   170  		if !ok {
   171  			return nil, nil, nil
   172  		}
   173  		return []*ir.Name{aname}, []*ir.Name{nil}, funcInlHeur.props
   174  	}
   175  
   176  	// Multi-return case
   177  	asgn, ok := cs.Assign.(*ir.AssignListStmt)
   178  	if !ok || !asgn.Def {
   179  		return nil, nil, nil
   180  	}
   181  	userVars := make([]*ir.Name, len(funcInlHeur.props.ResultFlags))
   182  	autoTemps := make([]*ir.Name, len(funcInlHeur.props.ResultFlags))
   183  	for idx, x := range asgn.Lhs {
   184  		if n, ok := x.(*ir.Name); ok {
   185  			userVars[idx] = n
   186  			r := asgn.Rhs[idx]
   187  			if r.Op() == ir.OCONVNOP {
   188  				r = r.(*ir.ConvExpr).X
   189  			}
   190  			if ir.IsAutoTmp(r) {
   191  				autoTemps[idx] = r.(*ir.Name)
   192  			}
   193  			if debugTrace&debugTraceScoring != 0 {
   194  				fmt.Fprintf(os.Stderr, "=-= multi-ret namedef uv=%v at=%v\n",
   195  					x, autoTemps[idx])
   196  			}
   197  		} else {
   198  			return nil, nil, nil
   199  		}
   200  	}
   201  	return userVars, autoTemps, funcInlHeur.props
   202  }
   203  
   204  func (rua *resultUseAnalyzer) nodeVisitPost(n ir.Node) {
   205  	rua.condLevelTracker.post(n)
   206  }
   207  
   208  func (rua *resultUseAnalyzer) nodeVisitPre(n ir.Node) {
   209  	rua.condLevelTracker.pre(n)
   210  	switch n.Op() {
   211  	case ir.OCALLINTER:
   212  		if debugTrace&debugTraceScoring != 0 {
   213  			fmt.Fprintf(os.Stderr, "=-= rescore examine iface call %v:\n", n)
   214  		}
   215  		rua.callTargetCheckResults(n)
   216  	case ir.OCALLFUNC:
   217  		if debugTrace&debugTraceScoring != 0 {
   218  			fmt.Fprintf(os.Stderr, "=-= rescore examine call %v:\n", n)
   219  		}
   220  		rua.callTargetCheckResults(n)
   221  	case ir.OIF:
   222  		ifst := n.(*ir.IfStmt)
   223  		rua.foldCheckResults(ifst.Cond)
   224  	case ir.OSWITCH:
   225  		swst := n.(*ir.SwitchStmt)
   226  		if swst.Tag != nil {
   227  			rua.foldCheckResults(swst.Tag)
   228  		}
   229  
   230  	}
   231  }
   232  
   233  // callTargetCheckResults examines a given call to see whether the
   234  // callee expression is potentially an inlinable function returned
   235  // from a potentially inlinable call. Examples:
   236  //
   237  //	Scenario 1: named intermediate
   238  //
   239  //	   fn1 := foo()         conc := bar()
   240  //	   fn1("blah")          conc.MyMethod()
   241  //
   242  //	Scenario 2: returned func or concrete object feeds directly to call
   243  //
   244  //	   foo()("blah")        bar().MyMethod()
   245  //
   246  // In the second case although at the source level the result of the
   247  // direct call feeds right into the method call or indirect call,
   248  // we're relying on the front end having inserted an auto-temp to
   249  // capture the value.
   250  func (rua *resultUseAnalyzer) callTargetCheckResults(call ir.Node) {
   251  	ce := call.(*ir.CallExpr)
   252  	rname := rua.getCallResultName(ce)
   253  	if rname == nil {
   254  		return
   255  	}
   256  	if debugTrace&debugTraceScoring != 0 {
   257  		fmt.Fprintf(os.Stderr, "=-= staticvalue returns %v:\n",
   258  			rname)
   259  	}
   260  	if rname.Class != ir.PAUTO {
   261  		return
   262  	}
   263  	switch call.Op() {
   264  	case ir.OCALLINTER:
   265  		if debugTrace&debugTraceScoring != 0 {
   266  			fmt.Fprintf(os.Stderr, "=-= in %s checking %v for cci prop:\n",
   267  				rua.fn.Sym().Name, rname)
   268  		}
   269  		if cs := rua.returnHasProp(rname, ResultIsConcreteTypeConvertedToInterface); cs != nil {
   270  
   271  			adj := returnFeedsConcreteToInterfaceCallAdj
   272  			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   273  		}
   274  	case ir.OCALLFUNC:
   275  		if debugTrace&debugTraceScoring != 0 {
   276  			fmt.Fprintf(os.Stderr, "=-= in %s checking %v for samefunc props:\n",
   277  				rua.fn.Sym().Name, rname)
   278  			v, ok := rua.resultNameTab[rname]
   279  			if !ok {
   280  				fmt.Fprintf(os.Stderr, "=-= no entry for %v in rt\n", rname)
   281  			} else {
   282  				fmt.Fprintf(os.Stderr, "=-= props for %v: %q\n", rname, v.props.String())
   283  			}
   284  		}
   285  		if cs := rua.returnHasProp(rname, ResultAlwaysSameInlinableFunc); cs != nil {
   286  			adj := returnFeedsInlinableFuncToIndCallAdj
   287  			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   288  		} else if cs := rua.returnHasProp(rname, ResultAlwaysSameFunc); cs != nil {
   289  			adj := returnFeedsFuncToIndCallAdj
   290  			cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   291  
   292  		}
   293  	}
   294  }
   295  
   296  // foldCheckResults examines the specified if/switch condition 'cond'
   297  // to see if it refers to locals defined by a (potentially inlinable)
   298  // function call at call site C, and if so, whether 'cond' contains
   299  // only combinations of simple references to all of the names in
   300  // 'names' with selected constants + operators. If these criteria are
   301  // met, then we adjust the score for call site C to reflect the
   302  // fact that inlining will enable deadcode and/or constant propagation.
   303  // Note: for this heuristic to kick in, the names in question have to
   304  // be all from the same callsite. Examples:
   305  //
   306  //	  q, r := baz()	    x, y := foo()
   307  //	  switch q+r {		a, b, c := bar()
   308  //		...			    if x && y && a && b && c {
   309  //	  }					   ...
   310  //					    }
   311  //
   312  // For the call to "baz" above we apply a score adjustment, but not
   313  // for the calls to "foo" or "bar".
   314  func (rua *resultUseAnalyzer) foldCheckResults(cond ir.Node) {
   315  	namesUsed := collectNamesUsed(cond)
   316  	if len(namesUsed) == 0 {
   317  		return
   318  	}
   319  	var cs *CallSite
   320  	for _, n := range namesUsed {
   321  		rpcs, found := rua.resultNameTab[n]
   322  		if !found {
   323  			return
   324  		}
   325  		if cs != nil && rpcs.defcs != cs {
   326  			return
   327  		}
   328  		cs = rpcs.defcs
   329  		if rpcs.props&ResultAlwaysSameConstant == 0 {
   330  			return
   331  		}
   332  	}
   333  	if debugTrace&debugTraceScoring != 0 {
   334  		nls := func(nl []*ir.Name) string {
   335  			r := ""
   336  			for _, n := range nl {
   337  				r += " " + n.Sym().Name
   338  			}
   339  			return r
   340  		}
   341  		fmt.Fprintf(os.Stderr, "=-= calling ShouldFoldIfNameConstant on names={%s} cond=%v\n", nls(namesUsed), cond)
   342  	}
   343  
   344  	if !ShouldFoldIfNameConstant(cond, namesUsed) {
   345  		return
   346  	}
   347  	adj := returnFeedsConstToIfAdj
   348  	cs.Score, cs.ScoreMask = adjustScore(adj, cs.Score, cs.ScoreMask)
   349  }
   350  
   351  func collectNamesUsed(expr ir.Node) []*ir.Name {
   352  	res := []*ir.Name{}
   353  	ir.Visit(expr, func(n ir.Node) {
   354  		if n.Op() != ir.ONAME {
   355  			return
   356  		}
   357  		nn := n.(*ir.Name)
   358  		if nn.Class != ir.PAUTO {
   359  			return
   360  		}
   361  		res = append(res, nn)
   362  	})
   363  	return res
   364  }
   365  
   366  func (rua *resultUseAnalyzer) returnHasProp(name *ir.Name, prop ResultPropBits) *CallSite {
   367  	v, ok := rua.resultNameTab[name]
   368  	if !ok {
   369  		return nil
   370  	}
   371  	if v.props&prop == 0 {
   372  		return nil
   373  	}
   374  	return v.defcs
   375  }
   376  
   377  func (rua *resultUseAnalyzer) getCallResultName(ce *ir.CallExpr) *ir.Name {
   378  	var callTarg ir.Node
   379  	if sel, ok := ce.Fun.(*ir.SelectorExpr); ok {
   380  		// method call
   381  		callTarg = sel.X
   382  	} else if ctarg, ok := ce.Fun.(*ir.Name); ok {
   383  		// regular call
   384  		callTarg = ctarg
   385  	} else {
   386  		return nil
   387  	}
   388  	r := ir.StaticValue(callTarg)
   389  	if debugTrace&debugTraceScoring != 0 {
   390  		fmt.Fprintf(os.Stderr, "=-= staticname on %v returns %v:\n",
   391  			callTarg, r)
   392  	}
   393  	if r.Op() == ir.OCALLFUNC {
   394  		// This corresponds to the "x := foo()" case; here
   395  		// ir.StaticValue has brought us all the way back to
   396  		// the call expression itself. We need to back off to
   397  		// the name defined by the call; do this by looking up
   398  		// the callsite.
   399  		ce := r.(*ir.CallExpr)
   400  		cs, ok := rua.cstab[ce]
   401  		if !ok {
   402  			return nil
   403  		}
   404  		names, _, _ := namesDefined(cs)
   405  		if len(names) == 0 {
   406  			return nil
   407  		}
   408  		return names[0]
   409  	} else if r.Op() == ir.ONAME {
   410  		return r.(*ir.Name)
   411  	}
   412  	return nil
   413  }
   414  

View as plain text