Haoyi's Programming Blog

How an Optimizing Compiler Works

Posted 2019-11-09

Optimizing compilers are a mainstay of modern software: allowing a programmer to write code in a language that makes sense to them, while transforming it into a form that makes sense for the underlying hardware to run efficiently. The optimizing compiler's job is to figure out what your input program does, and create an output program that it knows will do the same thing, just faster.

This post will walk through some of the core inference techniques within an optimizing compiler: how to model a program in a way that is easy to work with, how to infer facts about your program, and how to use those facts to make your program smaller and faster.


This blog post is a companion to a conference talk of the same name.

Program optimizers can run in any number of places: as a step within a larger compiler's process (e.g. the Scala Optimizer), as a separate standalone program you run after compilation and before execution (e.g. Proguard), or as part of the runtime that optimizes the program while it is already being executed (e.g. the JVM JIT compiler). The constraints imposed on optimizers in each case differ, but their goal is the same: to take an input program and convert it to an output program that does the same thing, faster.

In this post, we will first walk through a few sample optimizations on a strawman program, to get an intuition of what typical optimizers do and how to identify and perform such optimizations manually. We will then examine a few different ways in which programs can be represented, and finally look at the algorithms and techniques we can use to infer facts about these programs, and then simplify them to be both smaller and faster.

A Strawman Program

This post will use the Java programming language for all our examples. Java is a common language, compiles to a relatively simple "assembly" - Java Bytecode - and is ubiquitous in industry. This gives us a good canvas on which to explore compiler optimization techniques on real, runnable examples. The techniques described here apply equally to optimizing almost any other language.

To begin with, let us consider the following strawman program. This program runs a bunch of logic, logging to the standard output in the process, and finally returning a computed result. The program itself isn't meaningful, but will serve as an illustration of what kind of things you can do to optimize it while preserving the existing behavior:

static int main(int n){
  int count = 0, total = 0, multiplied = 0;
  Logger logger = new PrintLogger();
  while(count < n){
    count += 1;
    multiplied *= count;
    if (multiplied < 100) logger.log(count);
    total += ackermann(2, 2);
    total += ackermann(multiplied, n);
    int d1 = ackermann(n, 1);
    total += d1 * multiplied;
    int d2 = ackermann(n, count);
    if (count % 2 == 0) total += d2;
  }
  return total;
}

// https://en.wikipedia.org/wiki/Ackermann_function
static int ackermann(int m, int n){
  if (m == 0) return n + 1;
  else if (n == 0) return ackermann(m - 1, 1);
  else return ackermann(m - 1, ackermann(m, n - 1));
}

interface Logger{
  public Logger log(Object a);
}
static class PrintLogger implements Logger{
  public PrintLogger log(Object a){  System.out.println(a); return this; }
}
static class ErrLogger implements Logger{
  public ErrLogger log(Object a){ System.err.println(a); return this; }
}

For now, we will assume that this program is all there is: there are no other pieces of code calling into this, and this standalone program just enters main, performs its work, and returns. With that in mind, let us begin looking for optimizations!

Example Optimizations

Type Inference & Inlining

The first thing you may notice is that the variable logger has a type that's imprecise: while labelled as a Logger, we can infer from the code that it must be a specific subclass: a PrintLogger:

-  Logger logger = new PrintLogger();
+  PrintLogger logger = new PrintLogger();

Now that we know logger is a PrintLogger, we know that the call to logger.log can only have a single implementation, so we can inline it:

-    if (multiplied < 100) logger.log(count);
+    if (multiplied < 100) System.out.println(count);

This lets us make the program smaller by removing the unnecessary ErrLogger class since we never use it, and also removing the various public Logger log methods since we've inlined it into our only callsite.

Constant Folding

Separately, we can see that while count and total change over the course of the program, multiplied does not: it starts off as 0, and every time it is updated via multiplied *= count it remains 0. We can thus substitute 0 for it throughout the program:

 static int main(int n){
-  int count = 0, total = 0, multiplied = 0;
+  int count = 0, total = 0;
   PrintLogger logger = new PrintLogger();
   while(count < n){
     count += 1;
-     multiplied *= count;
-    if (multiplied < 100) System.out.println(count);
+    if (0 < 100) System.out.println(count);
     total += ackermann(2, 2);
-    total += ackermann(multiplied, n);
+    total += ackermann(0, n);
     int d1 = ackermann(n, 1);
-     total += d1 * multiplied;
     int d2 = ackermann(n, count);
     if (count % 2 == 0) total += d2;
   }
   return total;
 }

As a consequence, we can see that d1 * multiplied is always 0, and thus total += d1 * multiplied does nothing and can be removed:

-    total += d1 * multiplied;
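At its core, constant folding can be implemented as a bottom-up rewrite over expression trees: fold the children first, then evaluate any operator whose operands are now known constants. Here is a minimal sketch; the Expr/Lit/Var/Bin classes are illustrative toys, not any real compiler's API:

```java
// Toy expression tree; illustrative only, not from a real optimizer
interface Expr {}
record Lit(int value) implements Expr {}
record Var(String name) implements Expr {}
record Bin(String op, Expr lhs, Expr rhs) implements Expr {}

class ConstantFolder {
    static Expr fold(Expr e) {
        if (!(e instanceof Bin b)) return e;        // literals and variables are already folded
        Expr l = fold(b.lhs()), r = fold(b.rhs());  // fold the children first
        if (l instanceof Lit a && r instanceof Lit c) {
            switch (b.op()) {                       // both operands known: evaluate at "compile time"
                case "+": return new Lit(a.value() + c.value());
                case "*": return new Lit(a.value() * c.value());
                case "<": return new Lit(a.value() < c.value() ? 1 : 0); // 1 = true, 0 = false
            }
        }
        // x * 0 folds to 0 even when x is unknown, since these toy expressions
        // have no side effects; this is exactly what happened to `multiplied`
        if (b.op().equals("*") && (l instanceof Lit a2 && a2.value() == 0
                                || r instanceof Lit c2 && c2.value() == 0)) {
            return new Lit(0);
        }
        return new Bin(b.op(), l, r);               // can't fold any further
    }
}
```

With this, fold(new Bin("*", new Var("d1"), new Lit(0))) collapses to Lit(0), mirroring how total += d1 * multiplied was eliminated.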

Dead Code Elimination

Now that we've constant folded multiplied and removed the total += d1 * multiplied statement, d1 no longer has any uses, so we can follow up and remove the definition of int d1:

-     int d1 = ackermann(n, 1);

This no longer plays a part in the program, and since ackermann is a pure function, removing it does not change the output of the program in any way.

Similarly, after inlining logger.log earlier, logger itself is now unused and can be eliminated:

-   PrintLogger logger = new PrintLogger();

Branch Elimination

The first conditional in our loop is now contingent on 0 < 100. Since that is always true, we can simply remove the conditional:

- if (0 < 100) System.out.println(count);
+ System.out.println(count);

In general, for any conditional branch whose condition is always true, we can replace the conditional with its body; for any branch whose condition is always false, we can remove the conditional together with its body entirely.
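As a sketch, this rule is a single rewrite on a statement node whose condition has already been folded to a known constant. The If/Print/Nop records below are illustrative, with knownCond left null when the condition could not be folded:

```java
// Toy statement nodes; illustrative only, not from a real optimizer
interface Stmt {}
record Nop() implements Stmt {}                            // "no statement at all"
record Print(String what) implements Stmt {}
record If(Integer knownCond, Stmt body) implements Stmt {} // null = condition not known statically

class BranchEliminator {
    static Stmt simplify(Stmt s) {
        if (s instanceof If i && i.knownCond() != null) {
            // Always-true condition: replace the If with its body.
            // Always-false condition: drop the If and its body entirely.
            return i.knownCond() != 0 ? i.body() : new Nop();
        }
        return s; // condition unknown at compile time: leave the branch alone
    }
}
```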

Partial Evaluation

Let us take a look at the three remaining calls to ackermann:

    total += ackermann(2, 2);
    total += ackermann(0, n);
    int d2 = ackermann(n, count);

Of these, the first call has both arguments constant, and the second has a constant first argument. Given the definition of ackermann:

static int ackermann(int m, int n){
  if (m == 0) return n + 1;
  else if (n == 0) return ackermann(m - 1, 1);
  else return ackermann(m - 1, ackermann(m, n - 1));
}

We can partially evaluate the first two calls, simplifying them into:

-    total += ackermann(2, 2);
+    total += 7;
-    total += ackermann(0, n);
+    total += n + 1;
     int d2 = ackermann(n, count);
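We can sanity-check these partial evaluations by simply running the function: ackermann(2, 2) really does compute 7, and ackermann(0, n) hits the m == 0 base case immediately, returning n + 1:

```java
// The same ackermann definition as above, with a main that checks
// the constants the partial evaluation produced
class Demo {
    static int ackermann(int m, int n){
        if (m == 0) return n + 1;
        else if (n == 0) return ackermann(m - 1, 1);
        else return ackermann(m - 1, ackermann(m, n - 1));
    }

    public static void main(String[] args){
        System.out.println(ackermann(2, 2)); // prints 7
        System.out.println(ackermann(0, 9)); // prints 10, i.e. n + 1
    }
}
```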

Late Scheduling

Lastly, we can see that the definition of d2 only gets used in the if (count % 2 == 0) conditional. Since the computation of ackermann is pure, we can thus move that call into the conditional, so that it doesn't get computed unless it actually gets used:

-    int d2 = ackermann(n, count);
-    if (count % 2 == 0) total += d2;
+    if (count % 2 == 0) {
+      int d2 = ackermann(n, count);
+      total += d2;
+    }

This lets us avoid half the calls to ackermann(n, count), speeding things up by not calling it when not necessary.

In contrast, the System.out.println function is not pure, and thus we cannot move it in or out of conditionals without changing program semantics.

Actual Optimized Output

Putting together all these optimizations, the final output code is as follows:

static int main(int n){
  int count = 0, total = 0;
  while(count < n){
    count += 1;
    System.out.println(count);
    total += 7;
    total += n + 1;
    if (count % 2 == 0) {
      int d2 = ackermann(n, count);
      total += d2;
    }
  }
  return total;
}

static int ackermann(int m, int n){
  if (m == 0) return n + 1;
  else if (n == 0) return ackermann(m - 1, 1);
  else return ackermann(m - 1, ackermann(m, n - 1));
}

While all the optimizations above were done by hand, it is entirely possible to do them automatically. Below is the decompiled output of a prototype optimizer I have written for JVM programs:

static int main(int var0) {
  new Demo.PrintLogger();
  int var1 = 0;

  int var3;
  for(int var2 = 0; var2 < var0; var2 = var3) {
    System.out.println(var3 = 1 + var2);
    int var10000 = var3 % 2;
    int var7 = var1 + 7 + var0 + 1;
    var1 = var10000 == 0 ? var7 + ackermann(var0, var3) : var7;
  }

  return var1;
}

static int ackermann__I__TI1__I(int var0) {
  if (var0 == 0) return 2;
  else return ackermann(var0 - 1, var0 == 0 ? 1 : ackermann__I__TI1__I(var0 - 1));
}

static int ackermann(int var0, int var1) {
  if (var0 == 0) return var1 + 1;
  else return var1 == 0 ? ackermann__I__TI1__I(var0 - 1) : ackermann(var0 - 1, ackermann(var0, var1 - 1));
}

static class PrintLogger implements Demo.Logger {}

interface Logger {}

The decompiled code looks slightly different from the hand-optimized version. There are some things it didn't quite manage to optimize away (e.g. the unused call to new PrintLogger), and some things it did slightly differently (e.g. splitting ackermann into ackermann and ackermann__I__TI1__I), but otherwise this automated program optimizer has done almost all the same optimizations that we just did by hand, using our own logic and intelligence.

This raises the question: how?

Intermediate Representations

If you try to build a program optimizer, the first question you will encounter is possibly the most important: what, exactly, is "a program"?

You are no doubt used to writing and modifying programs as source code. You have definitely executed them as compiled binaries, and perhaps had to peek into those binaries at some point to debug issues. You may also have encountered programs in syntax tree form, as three-address code, A-Normal Form, Continuation Passing Style, or Static Single Assignment form.

There is a whole zoo of different ways you can represent programs. Here we will discuss a few of the most important ways you can represent "a program" within your program optimizer:

Sourcecode

static int ackermann(int m, int n){
  if (m == 0) return n + 1;
  else if (n == 0) return ackermann(m - 1, 1);
  else return ackermann(m - 1, ackermann(m, n - 1));
}

Un-compiled source code is, itself, a representation of your program. Source code tends to be relatively compact and human readable, but has two issues:

- It contains a large amount of detail that is irrelevant to what the program actually computes: formatting, whitespace, comments, and the names of local variables.
- It is fragile: a careless textual transformation can easily produce output that no longer parses or compiles.

These two factors make source code a difficult form of program for a program optimizer to work with. While you can perform program transformations on source, e.g. using regexes to try and identify patterns and replace them, the first point above means it is difficult to reliably identify patterns in the presence of all the irrelevant detail, and the second point makes it very easy to mess up the replacement and be left with an invalid output program.

While these restrictions are acceptable for program transformers which are run with manual oversight, e.g. using Codemod to refactor and transform a codebase, they make it infeasible to use source code as the primary data model of an automated program optimizer.

Abstract Syntax Trees

static int ackermann(int m, int n){
  if (m == 0) return n + 1;
  else if (n == 0) return ackermann(m - 1, 1);
  else return ackermann(m - 1, ackermann(m, n - 1));
}

IfElse(
    cond = BinOp(Ident("m"), "=", Literal(0)),
    then = Return(BinOp(Ident("n"), "+", Literal(1)),
    else = IfElse(
        cond = BinOp(Ident("n"), "=", Literal(0)),
        then = Return(Call("ackermann", BinOp(Ident("m"), "-", Literal(1)), Literal(1)),
        else = Return(
            Call(
                "ackermann",
                BinOp(Ident("m"), "-", Literal(1)),
                Call("ackermann", Ident("m"), BinOp(Ident("n"), "-", Literal(1)))
            )
        )
    )
)
(Diagram: the ackermann AST drawn as a tree, with a Block root, nested IfElse nodes, and BinOp, Return, Call, and literal leaves.)

Abstract Syntax Trees (ASTs) are another common intermediate format. One level up the abstraction ladder from source code, ASTs typically discard all formatting details of the sourcecode such as whitespace or comments, but preserve things such as local variable names that get discarded in more abstract representations.

Like source code, ASTs suffer from the ability to encode extraneous information that doesn't affect the actual semantics of your program. For example, the following two programs are semantically identical, differing only by the names of their local variables, and yet they have different ASTs:

static int ackermannA(int m, int n){
  int p = n;
  int q = m;
  if (q == 0) return p + 1;
  else if (p == 0) return ackermannA(q - 1, 1);
  else return ackermannA(q - 1, ackermannA(q, p - 1));
}
Block(
    Assign("p", Ident("n")),
    Assign("q", Ident("m")),
    IfElse(
        cond = BinOp(Ident("q"), "==", Literal(0)),
        then = Return(BinOp(Ident("p"), "+", Literal(1)),
        else = IfElse(
            cond = BinOp(Ident("p"), "==", Literal(0)),
            then = Return(Call("ackermann", BinOp(Ident("q"), "-", Literal(1)), Literal(1)),
            else = Return(
                Call(
                    "ackermann",
                    BinOp(Ident("q"), "-", Literal(1)),
                    Call("ackermann", Ident("q"), BinOp(Ident("p"), "-", Literal(1)))
                )
            )
        )
    )
)
static int ackermannB(int m, int n){
  int r = n;
  int s = m;
  if (s == 0) return r + 1;
  else if (r == 0) return ackermannB(s - 1, 1);
  else return ackermannB(s - 1, ackermannB(s, r - 1));
}
Block(
    Assign("r", Ident("n")),
    Assign("s", Ident("m")),
    IfElse(
        cond = BinOp(Ident("s"), "==", Literal(0)),
        then = Return(BinOp(Ident("r"), "+", Literal(1)),
        else = IfElse(
            cond = BinOp(Ident("r"), "==", Literal(0)),
            then = Return(Call("ackermann", BinOp(Ident("s"), "-", Literal(1)), Literal(1)),
            else = Return(
                Call(
                    "ackermann",
                    BinOp(Ident("s"), "-", Literal(1)),
                    Call("ackermann", Ident("s"), BinOp(Ident("r"), "-", Literal(1)))
                )
            )
        )
    )
)

The key insight here is that while the ASTs are structured as trees, they contain some nodes that semantically do not behave as trees: Ident("r") and Ident("s") have their meaning defined not by the contents of their subtree, but by the Assign("r", ...) and Assign("s", ...) nodes seen earlier. In effect, these Idents and Assigns have additional semantic edges between them, which are just as important as the edges in the AST's tree structure:

(Diagram: the same AST for ackermannA, with extra semantic edges drawn from each Ident node, e.g. p and q, back to the Assign node that defines it.)

These edges form a generalized graph structure, including cycles if there are recursive function definitions.
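Recovering these edges is typically done with a single walk over the AST and a symbol table: every Assign records itself under its name, and every Ident looks its name up to find the defining Assign. A minimal sketch, with illustrative node classes rather than a real frontend's AST:

```java
import java.util.*;

// Illustrative AST nodes, just enough to demonstrate Ident -> Assign edges
interface Node {}
record Assign(String name, Node rhs) implements Node {}
record Ident(String name) implements Node {}
record Block(List<Node> stmts) implements Node {}

class DefUse {
    // Builds the semantic edges: each Ident occurrence mapped to its defining Assign.
    // An IdentityHashMap keeps distinct occurrences of the same name separate.
    static Map<Ident, Assign> resolve(Node root) {
        Map<String, Assign> scope = new HashMap<>();        // the symbol table
        Map<Ident, Assign> edges = new IdentityHashMap<>();
        walk(root, scope, edges);
        return edges;
    }

    static void walk(Node n, Map<String, Assign> scope, Map<Ident, Assign> edges) {
        if (n instanceof Block b) for (Node s : b.stmts()) walk(s, scope, edges);
        else if (n instanceof Assign a) { walk(a.rhs(), scope, edges); scope.put(a.name(), a); }
        else if (n instanceof Ident i) edges.put(i, scope.get(i.name()));
    }
}
```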

Bytecode

Since we are using Java as our base language, compiled programs end up being stored as Java Bytecode in .class files.

Recall the ackermann function we saw earlier:

static int ackermann(int m, int n){
  if (m == 0) return n + 1;
  else if (n == 0) return ackermann(m - 1, 1);
  else return ackermann(m - 1, ackermann(m, n - 1));
}

This compiles to bytecode that looks like this:

   0: iload_0
   1: ifne          8
   4: iload_1
   5: iconst_1
   6: iadd
   7: ireturn
   8: iload_1
   9: ifne          20
  12: iload_0
  13: iconst_1
  14: isub
  15: iconst_1
  16: invokestatic ackermann:(II)I
  19: ireturn
  20: iload_0
  21: iconst_1
  22: isub
  23: iload_0
  24: iload_1
  25: iconst_1
  26: isub
  27: invokestatic ackermann:(II)I
  30: invokestatic ackermann:(II)I
  33: ireturn

The Java Virtual Machine (JVM) that Java bytecode runs on is a mixed stack/register machine: there is an operand STACK where values can be operated on, and an array of LOCALS where values can be stored. A function begins with its N parameters in the first N local variable slots, and executes by moving things onto the stack, operating on them, and putting them back into locals, eventually calling return to return a value on the operand stack to the caller.

If we annotate the above bytecode to represent the values moving between the stack and the local variables table, it looks as follows:

  BYTECODE                                LOCALS          STACK
                                          |a0|a1|         |
   0: iload_0                             |a0|a1|         |a0|
   1: ifne          8                     |a0|a1|         |
   4: iload_1                             |a0|a1|         |a1|
   5: iconst_1                            |a0|a1|         |a1| 1|
   6: iadd                                |a0|a1|         |v1|
   7: ireturn                             |a0|a1|         |
   8: iload_1                             |a0|a1|         |a1|
   9: ifne          20                    |a0|a1|         |
  12: iload_0                             |a0|a1|         |a0|
  13: iconst_1                            |a0|a1|         |a0| 1|
  14: isub                                |a0|a1|         |v2|
  15: iconst_1                            |a0|a1|         |v2| 1|
  16: invokestatic ackermann:(II)I        |a0|a1|         |v3|
  19: ireturn                             |a0|a1|         |
  20: iload_0                             |a0|a1|         |a0|
  21: iconst_1                            |a0|a1|         |a0| 1|
  22: isub                                |a0|a1|         |v4|
  23: iload_0                             |a0|a1|         |v4|a0|
  24: iload_1                             |a0|a1|         |v4|a0|a1|
  25: iconst_1                            |a0|a1|         |v4|a0|a1| 1|
  26: isub                                |a0|a1|         |v4|a0|v5|
  27: invokestatic ackermann:(II)I        |a0|a1|         |v4|v6|
  30: invokestatic ackermann:(II)I        |a0|a1|         |v7|
  33: ireturn                             |a0|a1|         |

Here, I have used a0 and a1 to represent the function's arguments, which are stored in the first slots of the LOCALS table at the start of a function. 1 represents the constant loaded via iconst_1, and v1 to v7 are the computed intermediate values. We can see that there are three ireturn instructions, returning v1, v3, or v7. This function does not define any other local variables, and so the LOCALS array only ever contains the input arguments.
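This stack/locals model is simple enough to simulate directly. Below is a toy interpreter for a handful of the instructions above (a rough sketch, nothing like a real JVM); running the iload_1; iconst_1; iadd; ireturn sequence from the m == 0 branch reproduces the v1 = a1 + 1 computation in the annotated trace:

```java
import java.util.*;

// Toy interpreter for a few JVM-like instructions; a sketch, not a real JVM
class MiniJvm {
    static int run(String[] code, int[] locals) {
        Deque<Integer> stack = new ArrayDeque<>();
        for (String insn : code) {
            switch (insn) {
                case "iload_0":  stack.push(locals[0]); break; // copy a LOCAL onto the STACK
                case "iload_1":  stack.push(locals[1]); break;
                case "iconst_1": stack.push(1); break;         // push the constant 1
                case "iadd":     stack.push(stack.pop() + stack.pop()); break;
                case "isub": { int b = stack.pop(), a = stack.pop(); stack.push(a - b); break; }
                case "ireturn":  return stack.pop();           // hand top-of-stack to the caller
            }
        }
        throw new IllegalStateException("fell off the end of the bytecode");
    }
}
```

run(new String[]{"iload_1", "iconst_1", "iadd", "ireturn"}, new int[]{0, 41}) returns 42: a1 is pushed, then 1, then iadd leaves v1 = a1 + 1 on the stack for ireturn.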

If you recall the two variants ackermannA and ackermannB we saw earlier, both of them compile to the same bytecode:

  BYTECODE                                LOCALS          STACK
                                          |a0|a1|         |
   0: iload_1                             |a0|a1|         |a1|
   1: istore_2                            |a0|a1|a1|      |
   2: iload_0                             |a0|a1|a1|      |a0|
   3: istore_3                            |a0|a1|a1|a0|   |
   4: iload_3                             |a0|a1|a1|a0|   |a0|
   5: ifne          12                    |a0|a1|a1|a0|   |
   8: iload_2                             |a0|a1|a1|a0|   |a1|
   9: iconst_1                            |a0|a1|a1|a0|   |a1| 1|
  10: iadd                                |a0|a1|a1|a0|   |v1|
  11: ireturn                             |a0|a1|a1|a0|   |
  12: iload_2                             |a0|a1|a1|a0|   |a1|
  13: ifne          24                    |a0|a1|a1|a0|   |
  16: iload_3                             |a0|a1|a1|a0|   |a0|
  17: iconst_1                            |a0|a1|a1|a0|   |a0| 1|
  18: isub                                |a0|a1|a1|a0|   |v2|
  19: iconst_1                            |a0|a1|a1|a0|   |v2| 1|
  20: invokestatic ackermannA:(II)I       |a0|a1|a1|a0|   |v3|
  23: ireturn                             |a0|a1|a1|a0|   |
  24: iload_3                             |a0|a1|a1|a0|   |a0|
  25: iconst_1                            |a0|a1|a1|a0|   |a0| 1|
  26: isub                                |a0|a1|a1|a0|   |v4|
  27: iload_3                             |a0|a1|a1|a0|   |v4|a0|
  28: iload_2                             |a0|a1|a1|a0|   |v4|a0|a1|
  29: iconst_1                            |a0|a1|a1|a0|   |v4|a0|a1| 1|
  30: isub                                |a0|a1|a1|a0|   |v4|a0|v5|
  31: invokestatic ackermannA:(II)I       |a0|a1|a1|a0|   |v4|v6|
  34: invokestatic ackermannA:(II)I       |a0|a1|a1|a0|   |v7|
  37: ireturn                             |a0|a1|a1|a0|   |

Here, we can see that because the source code takes the two arguments and puts them in local variables, the bytecode has corresponding instructions to load the argument values from LOCALS 0 and 1 and store them back into indices 2 and 3. However, bytecode does not care about the names of your local variables: they're treated purely as indices within the LOCALS array, meaning both ackermannA and ackermannB have identical bytecode. This makes sense, given that they are semantically equivalent!

However, ackermannA and ackermannB do not compile to the same bytecode as the original ackermann: while bytecode abstracts away the names of your local variables, it still does not fully abstract away the loads and stores to/from locals. We still find ourselves having to deal with the details of how values are moved around the LOCALS and STACK, even though they do not affect how the program actually behaves.

Apart from the lack of load/store abstractions, bytecode has another problem: like most linear assemblies, it is very much optimized for compactness, and can be quite difficult to modify when the time comes to perform optimizations.

To understand why, let us revisit the original bytecode for ackermann:

  BYTECODE                                LOCALS          STACK
                                          |a0|a1|         |
   0: iload_0                             |a0|a1|         |a0|
   1: ifne          8                     |a0|a1|         |
   4: iload_1                             |a0|a1|         |a1|
   5: iconst_1                            |a0|a1|         |a1| 1|
   6: iadd                                |a0|a1|         |v1|
   7: ireturn                             |a0|a1|         |
   8: iload_1                             |a0|a1|         |a1|
   9: ifne          20                    |a0|a1|         |
  12: iload_0                             |a0|a1|         |a0|
  13: iconst_1                            |a0|a1|         |a0| 1|
  14: isub                                |a0|a1|         |v2|
  15: iconst_1                            |a0|a1|         |v2| 1|
  16: invokestatic ackermann:(II)I        |a0|a1|         |v3|
  19: ireturn                             |a0|a1|         |
  20: iload_0                             |a0|a1|         |a0|
  21: iconst_1                            |a0|a1|         |a0| 1|
  22: isub                                |a0|a1|         |v4|
  23: iload_0                             |a0|a1|         |v4|a0|
  24: iload_1                             |a0|a1|         |v4|a0|a1|
  25: iconst_1                            |a0|a1|         |v4|a0|a1| 1|
  26: isub                                |a0|a1|         |v4|a0|v5|
  27: invokestatic ackermann:(II)I        |a0|a1|         |v4|v6|
  30: invokestatic ackermann:(II)I        |a0|a1|         |v7|
  33: ireturn                             |a0|a1|         |

And from here, make a strawman change: suppose we have decided that the function call 30: invokestatic ackermann:(II)I does not make use of its first argument, and thus we can replace it with an equivalent call 30: invokestatic ackermann2:(I)I that takes only one argument. This is a common optimization, and can allow any code that was used to compute the first argument of the original call to be eliminated by the Dead Code Elimination we saw earlier.

To do so, we need to do more than replace instruction 30: we also need to walk backwards through the list of instructions to figure out where its first argument (v4 on the STACK) gets computed, and remove those instructions as well. We end up walking backwards from instruction 30 to 22, and from 22 to 21 and 20. The final change would look something like this:

  BYTECODE                                LOCALS          STACK
                                          |a0|a1|         |
   0: iload_0                             |a0|a1|         |a0|
   1: ifne          8                     |a0|a1|         |
   4: iload_1                             |a0|a1|         |a1|
   5: iconst_1                            |a0|a1|         |a1| 1|
   6: iadd                                |a0|a1|         |v1|
   7: ireturn                             |a0|a1|         |
   8: iload_1                             |a0|a1|         |a1|
   9: ifne          20                    |a0|a1|         |
  12: iload_0                             |a0|a1|         |a0|
  13: iconst_1                            |a0|a1|         |a0| 1|
  14: isub                                |a0|a1|         |v2|
  15: iconst_1                            |a0|a1|         |v2| 1|
  16: invokestatic ackermann:(II)I        |a0|a1|         |v3|
  19: ireturn                             |a0|a1|         |
- 20: iload_0                             |a0|a1|         |
- 21: iconst_1                            |a0|a1|         |
- 22: isub                                |a0|a1|         |
  23: iload_0                             |a0|a1|         |a0|
  24: iload_1                             |a0|a1|         |a0|a1|
  25: iconst_1                            |a0|a1|         |a0|a1| 1|
  26: isub                                |a0|a1|         |a0|v5|
  27: invokestatic ackermann:(II)I        |a0|a1|         |v6|
- 30: invokestatic ackermann:(II)I        |a0|a1|         |v7|
+ 30: invokestatic ackermann2:(I)I        |a0|a1|         |v7|
  33: ireturn                             |a0|a1|         |

While making such a trivial change in the simple ackermann function is still possible, it gets much messier with more involved changes in larger, real-world functions. In general, every small semantic change to your program may result in a large number of changes scattered throughout the bytecode.

You may have noticed that the way we performed this change was to look at the LOCALS and STACK values we have simulated on the right: seeing that v4 arrives at instruction 30 from instruction 22, and that instruction 22 takes in a0 and 1, which arrive from instructions 20 and 21. These values being passed around between LOCALS and STACK form a graph: from the instruction that computes each value, to the places where it is used.

Like the Ident/Assign pairs in our ASTs, the values being moved around the LOCALS and STACK form a graph between the computation of a value, and its usages. Why not work with the graph directly?

Dataflow Graphs

Dataflow Graphs are one level up the abstraction ladder from Bytecode or ASTs. If we extend our syntax tree with the Ident/Assign edges, or if we follow how the bytecode is moving values between LOCALS and STACK, we will end up forming the same graph. For the ackermann function, it looks like this:

(Diagram: the dataflow graph of ackermann: nodes for the arguments m and n, constants, operators, calls, and returns, grouped into basic blocks, with directed edges from each value to its usages.)

Unlike ASTs or Java's stack/register-based bytecode, dataflow graphs do not have the concept of a "local variable": instead, the graph contains directed edges between each value and its usages. While analyzing bytecode often involves abstract interpretation of the LOCALS and STACK to figure out how values are being moved around, and analyzing ASTs involves walking the tree while managing a symbol table of Assign/Ident edges, analyzing dataflow graphs is often just a matter of straightforward graph traversals: the whole idea of "moving a value" is removed from the program representation.

Dataflow graphs are also somewhat easier to manipulate than linear bytecode: replacing a function call node ackermann with an ackermann2 call, and discarding one of its arguments, is a matter of mutating a graph node (below in green) and removing one input edge, along with its transitive upstream edges/nodes, from the graph (below in red):

(Diagram: the same dataflow graph, with the ackermann node rewritten to ackermann2, shown in green, and the subgraph that computed its discarded first argument removed, shown in red.)

Thus, we can see how a small change to the program (replacing ackermann(int m, int n) with ackermann2(int n)) ends up being a relatively localized change in the dataflow graph.

In general, dataflow graphs are much easier to work with than linear bytecodes or ASTs: both graph analyses and graph modifications are straightforward.
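For example, dead-node elimination on a dataflow graph is a plain reachability traversal: start from the nodes with observable effects (returns, calls to impure functions) and walk the input edges; any node never visited computes a value nobody uses and can be dropped. A minimal sketch with an illustrative Node class:

```java
import java.util.*;

// Illustrative dataflow-graph node: an operation plus edges to its inputs
class Node {
    final String op;
    final List<Node> inputs;
    Node(String op, Node... inputs) { this.op = op; this.inputs = List.of(inputs); }
}

class DeadNodeElimination {
    // Everything transitively reachable from the roots (returns, side-effecting
    // calls) is live; all other nodes can be deleted from the graph
    static Set<Node> liveNodes(List<Node> roots) {
        Set<Node> live = Collections.newSetFromMap(new IdentityHashMap<>());
        Deque<Node> todo = new ArrayDeque<>(roots);
        while (!todo.isEmpty()) {
            Node n = todo.pop();
            if (live.add(n)) todo.addAll(n.inputs); // visit each node's inputs once
        }
        return live;
    }
}
```

In the ackermann2 rewrite above, the node computing the discarded first argument has no path to any return, so it never enters the live set and is removed along with its own inputs.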

This description of dataflow graphs leaves out a lot of detail: apart from the physical representation of the graph, there are many different ways of modeling state and control flow, whose handling would be more involved and would not fit within this blog post. We have also left out the details of graph transformations, e.g. adding/removing edges, forward/backwards breadth/depth-first traversals, etc., which should be familiar to anyone who has taken an introductory algorithms course.

Lastly, we have skipped over the algorithms for converting from a linear bytecode to the dataflow graph, and then from the dataflow graph back to a linear bytecode. That is itself an interesting problem, but for now will have to be left as an exercise for the reader.

Inference

Once you have a representation of your program, the next step is inference: finding out facts about your program so you can know how to transform it without changing its behavior. Indeed, many of the optimizations we saw earlier are based on inferring facts about our program:

- Type Inference & Inlining relied on knowing that logger is always a PrintLogger
- Constant Folding relied on knowing that multiplied is always 0
- Dead Code Elimination relied on knowing that d1 and logger are never used
- Branch Elimination relied on knowing that the condition 0 < 100 is always true
- Late Scheduling relied on knowing that ackermann is pure, and that d2 is only used in one branch

The more facts we can infer about a program, and the more specific those facts can be, the more aggressively we can transform the program to optimize it without changing its semantics. We will discuss the inference of types, constants, liveness and reachability in this post, and leave the analysis of purity as an exercise for the reader.

Types, constants, purity, and liveness are just some of the most common things an optimizer might want to infer about your program, and inference in the optimizer works much like the inference of types in a compiler's frontend typechecker.

The Inference Lattice

To begin with, let us consider how inferring types and constants would work. At its core, a "type" represents something we know about a value.

The potential types a value in your program could take form a lattice:

[Graph: a type lattice with Any at the top, branching into Integer, CharSequence, and Array[Float]; CharSequence branches into String and StringBuilder]

Inferring constant values is very similar to inferring types: in a way, the constant string "hello" is a subtype of String, just like String is a subtype of CharSequence. Since there is only one string "hello", such types are often called "Singleton Types". Extending our lattice to include these is straightforward.

[Graph: the same lattice extended with singleton types: 0, 1, 2, ... below Integer, and "hello", "world", ... below String]

The further up the lattice the type we assign to a value, the less we know about it. The further down the lattice, the more we know. If we know a value has one of two types, e.g. it is either a String OR a StringBuilder then we can conservatively model it as the least-upper-bound type of the two: a CharSequence. If we have a value that is either 0 or 1 or 2, we infer it to be an Integer.
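As a sketch of how the least-upper-bound computation could work, we can encode the lattice above as a child-to-parent map and find the lowest common ancestor of two elements (the `Lattice` class and its string encoding here are hypothetical, purely for illustration):

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// A hypothetical encoding of the type lattice as a child -> parent map; the
// least upper bound of two elements is then their lowest common ancestor.
class Lattice {
    static final Map<String, String> parent = new HashMap<>();
    static {
        parent.put("Integer", "Any");
        parent.put("CharSequence", "Any");
        parent.put("Array[Float]", "Any");
        parent.put("String", "CharSequence");
        parent.put("StringBuilder", "CharSequence");
        parent.put("\"hello\"", "String");
        parent.put("\"world\"", "String");
        parent.put("0", "Integer");
        parent.put("1", "Integer");
        parent.put("2", "Integer");
    }

    // Collect every ancestor of `a` (including itself), then walk up from `b`
    // until we hit one of them: the first hit is the least upper bound.
    static String lub(String a, String b) {
        Set<String> ancestors = new HashSet<>();
        for (String t = a; t != null; t = parent.get(t)) ancestors.add(t);
        for (String t = b; t != null; t = parent.get(t)) {
            if (ancestors.contains(t)) return t;
        }
        return "Any"; // the root bounds everything
    }
}
```

For example, `lub("String", "StringBuilder")` yields `"CharSequence"`, and `lub("0", "1")` yields `"Integer"`, matching the merges described above.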

Note that exactly how fine-grained you want to make your lattice is a design decision. In this case we decided that the parent of 0, 1, 2, 3, and so on is Integer:

[Graph: Integer as the direct parent of 0, 1, 2, 3, ...]

But we could come up with a more fine-grained lattice, e.g. one separating even and odd numbers:

[Graph: Integer split into EvenInteger (parent of 0, 2, ...) and OddInteger (parent of 1, 3, ...)]

The more fine-grained your lattice, the more precise your inferences, but the longer your analysis will take to run. Exactly how fine-grained you want it to be ends up being a judgement call.
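For instance, a sketch of what the even/odd refinement buys you: an abstract "parity" domain with a transfer function for addition (the `Parity` enum is hypothetical, with `INTEGER` standing for "unknown parity"):

```java
// A sketch of the finer-grained lattice: an abstract domain tracking parity,
// where INTEGER means "unknown parity". The transfer function for + shows how
// a more precise domain preserves more information across operations.
enum Parity {
    EVEN, ODD, INTEGER;

    static Parity plus(Parity a, Parity b) {
        if (a == INTEGER || b == INTEGER) return INTEGER; // information lost
        return a == b ? EVEN : ODD; // even+even = odd+odd = even; mixed = odd
    }
}
```

With the coarser Integer-only lattice, the analyzer would have to answer `INTEGER` for every addition; the parity domain can still prove, say, that the sum of two odd values is even.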

Inferring Count

Let us look at a simplified version of the main program we saw earlier:

static int main(int n){
  int count = 0, multiplied = 0;
  while(count < n){
    if (multiplied < 100) logger.log(count);
    count += 1;
    multiplied *= count;
  }
  return ...;
}

For simplicity, we have removed the various calls to ackermann, leaving only the code using count, multiplied, and logger. First, let's look at the dataflow graph:

[Graph: the dataflow graph of the simplified main, spanning blocks block0 to block3, with the count and multiplied nodes, the < and if nodes, logger.log, and the loop back-edge]

We see that count is initialized to 0 in block0. Next, control flow shifts to block1, where we check whether count < n: if it is, we continue into the loop body, which conditionally logs count, increments count by 1 and stores it back, and updates multiplied, before returning control to block1 to repeat the check. This loops until the < conditional returns false, at which point control moves to block3 and the function returns.

How do we analyze this? Starting from the constant 0 flowing into count, a first pass infers count -> 0; but following the loop back-edge, the + node produces 0 + 1 = 1, and merging 0 with 1 at count forces the inference up the lattice to Integer. A second pass around the loop with count -> Integer produces Integer again, so the inference has converged: all we can say about count is that it is an Integer.

Inferring Multiplied

Let us now consider the part of the same dataflow graph involving multiplied.

Since multiplied is inferred to be 0, we now know that the comparison in block1b is always 0 < 100, which is always true: the if-conditional can be eliminated and logger.log(count) made unconditional. We also know the * node always produces 0, so multiplied and the multiplication feeding it are no longer needed.

These modifications are marked in red below:

[Graph: the same dataflow graph, with the nodes affected by the multiplied -> 0 inference marked in red: the multiplied and * nodes and the always-true < 100 conditional]

This gives us the following remaining dataflow graph:

[Graph: the remaining dataflow graph, with multiplied, the * node, and the < 100 conditional removed]

Serializing this back to bytecode gives us the following program:

static int main(int n){
  int count = 0;
  while(count < n){
    logger.log(count);
    count += 1;
  }
  return ...;
}

Thus, we can see that by modeling the program with an appropriate data structure, a simple set of graph traversals is enough to simplify the messy, inefficient code we saw earlier to the tight, efficient loop we see here.

The inference that multiplied -> 0 can also feed into other optimizations, e.g. the Partial Evaluation step we saw earlier. In general, inferences allow further inferences, and optimizations allow further optimizations. Much of the difficulty in designing a program optimizer is ensuring you can leverage these downstream opportunities for inference and optimization in an effective and efficient manner.


In this section, I have shown how basic inference works on a dataflow graph: assigning each value a position in the lattice, propagating those positions along the graph's edges, and iterating around loops until the inferences converge.

In this case, we could perform the analysis for count first and the analysis for multiplied separately after. This is only possible because multiplied depends on count, but count does not depend on multiplied. If there is a mutual dependency, we can still do the analysis, but must analyze both variables together in one traversal of the dataflow graph.

Note that this inference is an iterative process: you keep going around loops until your inferences stop changing. In general, program dataflow graphs without cycles (whether loops or recursion) can always be analyzed in one pass, whereas programs with loops or recursion will need iteration for the inferences to converge.
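A minimal sketch of this iterate-until-stable process, applied to `count` in the loop above (the lattice is pared down to "known constant or Integer", and all names here are hypothetical):

```java
// A sketch of fixed-point inference for `count`, using a pared-down lattice
// whose elements are either a known int constant or INTEGER ("any int").
// We repeatedly apply the loop's effect (count = count + 1), merge the result
// with the old inference at the loop back-edge, and stop once nothing changes.
class CountInference {
    static final Object INTEGER = "Integer"; // top of this tiny lattice

    static Object lub(Object a, Object b) {
        return a.equals(b) ? a : INTEGER; // distinct constants generalize
    }

    static Object addOne(Object v) {
        return v.equals(INTEGER) ? INTEGER : (Integer) v + 1; // constant-fold
    }

    static Object inferCount() {
        Object count = 0; // block0: count = 0
        while (true) { // iterate until the inference converges
            Object next = lub(count, addOne(count)); // merge loop back-edge
            if (next.equals(count)) return count;
            count = next;
        }
    }
}
```

Here the first round widens count from the constant 0 to Integer, and the second round confirms that nothing further changes, mirroring the two trips around the while loop described below.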

While in this case we only needed to make two rounds around the while loop, it is possible to come up with programs that take O(height-of-lattice) iterations to analyze. Thus, while a more fine-grained lattice (e.g. the one separating EvenIntegers and OddIntegers we saw earlier) could improve the precision of analysis, it also increases the time it takes to run, and exactly where you make that trade-off depends on the use case.

A similar technique can be also used to infer properties of arbitrary recursive functions, as we will see in a moment.

Inter-procedural Function Inference

Above, we performed inference of values within a single function body, ignoring for simplicity calls to other functions or methods. However, most programs are made of many functions calling each other, so working with real programs forces you to infer properties not just about single expressions within a function, but also about the functions themselves and how they relate to each other in a call graph.

Simple Non-Recursive Functions

In most cases, handling other functions is easy. Consider this tiny example program:

static int main(int n){
    return called(n, 0);
}

static int called(int x, int y){
    return x * y;
}

The dataflow graph for these two functions looks something like this:

[Graph: the dataflow graphs of main and called, with main's call node connected to called's parameters and return]

Inferring the return value of main proceeds as follows: analysis of main reaches the call to called, pauses, and analyzes called with the abstract arguments (Integer, 0); within called, x * y with y known to be 0 is inferred to be the constant 0; that 0 then flows back to the call site in main.

Essentially, when you reach a function call during inference, you simply perform inference on the callee before returning to continue inference on the caller. This stack of inferences can go arbitrarily deep, and mirrors the call stack that will be present when your program is running.
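This pause-analyze-resume process can be sketched abstractly for the example above (all names hypothetical; abstract values here are either a known constant or the string `"Integer"`):

```java
// A sketch of inter-procedural inference for the example above: on reaching
// called(n, 0), we pause the analysis of main, analyze called's body with the
// abstract arguments (Integer, 0), and resume with the callee's result.
class InterProc {
    static final Object INTEGER = "Integer"; // "some int, value unknown"

    // Abstract multiplication: anything times the constant 0 is 0.
    static Object mul(Object x, Object y) {
        if (Integer.valueOf(0).equals(x) || Integer.valueOf(0).equals(y)) return 0;
        return INTEGER;
    }

    // Analyze `called(x, y){ return x * y; }` under abstract arguments.
    static Object analyzeCalled(Object x, Object y) { return mul(x, y); }

    // Analyze `main(n){ return called(n, 0); }`: the call site's inferred
    // value is whatever the callee's analysis returns.
    static Object analyzeMain() { return analyzeCalled(INTEGER, 0); }
}
```

Note how the Java call stack during analysis (analyzeMain calling analyzeCalled) mirrors the runtime call stack of the program being analyzed.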

Since we now know called(n, 0) is 0, we can use that knowledge to simplify the dataflow graph:

[Graph: main's dataflow graph simplified to return the constant 0]

Which can be serialized back to the following code:

static int main(int n){
    return 0;
}

This works well as long as your functions are not recursive: if function A calls B, B calls C, and C calls D, then D returns its inference to C, C to B, and B to A. However, if function A calls B and B calls A, or even if A just calls itself recursively, this falls apart: the chain of analyses will never return!

A Recursive Factorial Function

Let's consider a simple recursive function, written in pseudo-Java:

public static Any factorial(int n) {
    if (n == 1) {
        return 1;
    } else {
        return n * factorial(n - 1);
    }
}

It takes an int n, but its return type is labeled Any: the declaration does not state what it returns. Visually, we can clearly see that factorial returns an int (or Integer in our lattice), but if we try to fully analyze factorial's body before returning, we will try to analyze the recursive call to factorial before our own analysis of factorial completes, leaving us in infinite recursion! How can our inference engine determine the return type on our behalf?

Inference with Bottom

The solution to this is to extend our inference lattice with the special value Bottom:

[Graph: the lattice extended with a Bottom element below every leaf value]

The purpose of Bottom is to say "we don't know what this is yet, but we will come back later and fill it out". We can use Bottom on the factorial dataflow graph as follows:

[Graph: factorial's dataflow graph: block0 checks n == 1, the true branch returns 1, and the false branch computes n * factorial(n - 1), with the recursive call initially stubbed as Bottom]

Thus, we can infer that factorial returns an Integer, despite it being a recursive function without a declared return type.


While these examples are a bit contrived, they serve to illustrate the way inference can proceed in the presence of function calls. Even for recursive functions, you stub out the recursive calls with Bottom when you first see them, and when the first analysis pass completes, you replace the stub with the first-pass inference and propagate the change through the dataflow graph.
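A sketch of this stub-and-iterate process for factorial (all names hypothetical; for brevity the body's analysis is hard-coded, and the constant 1 is coarsened straight to Integer, so this converges in one fewer pass than a full constant-tracking analysis would):

```java
import java.util.HashMap;
import java.util.Map;

// A sketch of recursive-function inference with a Bottom stub: the first time
// a function is seen, its inference is recorded as Bottom; its body is then
// re-analyzed, with the recorded inference substituted at recursive call
// sites, until the result stops changing.
class RecursiveInference {
    static final String BOTTOM = "Bottom", INT = "Integer";
    static final Map<String, String> inferred = new HashMap<>();

    // lub over the slice of the lattice we need: Bottom is below everything.
    static String lub(String a, String b) {
        if (a.equals(BOTTOM)) return b;
        if (b.equals(BOTTOM)) return a;
        return a.equals(b) ? a : "Any";
    }

    // Abstract evaluation of factorial's body:
    //   if (n == 1) return 1; else return n * factorial(n - 1);
    // The true branch yields an Integer; the false branch is Integer times
    // whatever factorial is currently inferred to return.
    static String analyzeFactorialBody() {
        String trueBranch = INT;                      // "return 1"
        String recursive = inferred.get("factorial"); // stub or prior pass
        String falseBranch = recursive.equals(BOTTOM) ? BOTTOM : INT;
        return lub(trueBranch, falseBranch);
    }

    static String inferFactorial() {
        inferred.put("factorial", BOTTOM);            // stub before analyzing
        while (true) {
            String result = analyzeFactorialBody();
            if (result.equals(inferred.get("factorial"))) return result;
            inferred.put("factorial", result);        // propagate and re-run
        }
    }
}
```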

In this case, the * expression is analyzed three times: first with the recursive call stubbed out as Bottom, making the false branch Bottom and the function's inferred return lub(1, Bottom) = 1; again after that 1 is substituted into the recursive call, making the false branch n * 1, an Integer, and the return lub(1, Integer) = Integer; and a final time after Integer is substituted in, at which point the result no longer changes and the inference has converged.

As was the case when inferring the loop constant multiplied earlier, this might happen up to O(height-of-lattice) times before the inference converges. This approach generalizes to other styles of recursive functions, and to inferring other properties such as purity.

Inferring Liveness and Reachability

Liveness and Reachability are a pair of analyses that drive the optimizations generally grouped under "dead code elimination". They find all the code whose values contribute to the final returned result ("Live") and all the code that the control flow of the program can reach ("Reachable"): code that fails either test is a safe candidate for removal.

To illustrate these two analyses, let us consider another simplified version of our initial main function:

static int main(int n){
  int count = 0, total = 0, multiplied = 0;
  while(count < n){
    if (multiplied > 100) count += 1;
    count += 1;
    multiplied *= 2;
    total += 1;
  }
  return count;
}

Unlike our earlier versions, here we flipped the conditional around in if (multiplied > 100), and have changed multiplied *= count to multiplied *= 2 to simplify the program graphs without loss of generality.

Traversing Backwards for Liveness

There are two issues we can identify with the program as written above: first, since multiplied starts at 0 and is only ever multiplied, it remains 0 forever, so the conditional multiplied > 100 can never be true and the count += 1 inside it never runs; second, total is incremented on every iteration but never read, so all the work maintaining it is wasted.

Ideally, this should be written as:

static int main(int n){
  int count = 0;
  while(count < n){
    count += 1;
  }
  return count;
}

Let us now see how a program analyzer can perform this rewrite automatically.

First, let's take a look at the dataflow graph:

[Graph: the dataflow graph of the liveness/reachability example, spanning blocks block0 to block3, including the count, multiplied, and total nodes]

This is a bit messier than the graphs we have seen earlier, but should still be readable: you start off at block0 and go to block1; if block1's conditional is true you go to block1b, if block1b's conditional is true you go on to block1c, and so on, eventually exiting at block3's return node.

Removing the Dead and Unreachable

We can run the same type/constant inference we did earlier to find multiplied -> 0, and modify the graph accordingly:

[Graph: the same graph with the nodes affected by the multiplied -> 0 inference marked in red]

This results in the following dataflow graph:

[Graph: the resulting dataflow graph, where the conditional in block1b now compares 0 > 100]

Now, we can clearly see that the conditional in block1b (0 > 100) will never evaluate to true. That means the branch taken when it is true, block1c, is unreachable, and can be discarded (and the if-conditional eliminated):

[Graph: the graph with the unreachable block1c discarded and the if-conditional in block1b eliminated]
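The reachability side of this cleanup can be sketched as a forward traversal over a control-flow graph encoded as a successor map (the block names and encoding are hypothetical): any block never visited from the entry is unreachable.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// A sketch of reachability analysis: starting from the entry block, follow
// control-flow edges forwards; any block never visited is unreachable and can
// be deleted.
class Reachability {
    static Set<String> reachable(Map<String, List<String>> successors, String entry) {
        Set<String> seen = new HashSet<>();
        Deque<String> worklist = new ArrayDeque<>();
        worklist.push(entry);
        while (!worklist.isEmpty()) {
            String block = worklist.pop();
            if (!seen.add(block)) continue; // already visited
            for (String succ : successors.getOrDefault(block, List.of()))
                worklist.push(succ); // follow control-flow edges forwards
        }
        return seen;
    }
}
```

Once constant folding removes block1b's always-false edge into block1c, block1c simply drops out of the reachable set and can be deleted.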

Lastly, we can see that we have a number of "dead-end nodes": nodes like total and > which compute some value, but are not themselves used in any downstream conditional, return or computation. These can be discovered by traversing the dataflow graph backwards starting from the return node, collecting all the live nodes and eliminating the rest: the >, 100, 0 in block1b, total, the 0 and + 1 feeding into total from block0 and block2. This leaves you with the following:

[Graph: the final dataflow graph after removing the dead-end nodes: total, the > comparison, and their inputs]

Which when serialized back to bytecode, would produce the "ideal" program we saw earlier:

static int main(int n){
  int count = 0;
  while(count < n){
    count += 1;
  }
  return count;
}
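The backwards liveness traversal described above can be sketched the same way, this time walking dataflow edges from the return node back towards each node's inputs (the input map is a hypothetical encoding of the graph):

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// A sketch of the backwards liveness traversal: starting from the return
// node, walk each node's input edges; anything never visited is a dead-end
// node whose value is computed but unused, and can be removed.
class Liveness {
    static Set<String> live(Map<String, List<String>> inputs, String root) {
        Set<String> seen = new HashSet<>();
        Deque<String> worklist = new ArrayDeque<>();
        worklist.push(root);
        while (!worklist.isEmpty()) {
            String node = worklist.pop();
            if (!seen.add(node)) continue; // already visited
            for (String in : inputs.getOrDefault(node, List.of()))
                worklist.push(in); // traverse dataflow edges backwards
        }
        return seen;
    }
}
```

Run over the example graph, this traversal reaches count and its inputs but never total or the > comparison, which is exactly the set of dead-end nodes we deleted by hand.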

Conclusion

In this post, we have gone through how program inference can be performed inside an optimizing compiler: modeling the program as a dataflow graph, inferring types and constants by assigning each value a position in a lattice, iterating around loops and recursive calls until those inferences converge, and using liveness and reachability to strip out dead and unreachable code.

Through this, we have seen how it is possible to turn a manual process of program optimization performed by a human into an automated process, running simple algorithms over simple data structures.

This is only a small taste of the kinds of work that an optimizing compiler does to make your code fast. There is a lot this post cannot cover; for further reading, textbooks on compilers and static program analysis are good starting points.


About the Author: Haoyi is a software engineer, an early contributor to Scala.js, and the author of many open-source Scala tools such as the Ammonite REPL and FastParse.

If you've enjoyed this blog, or enjoyed using Haoyi's other open source libraries, please chip in (or get your Company to chip in!) via Patreon so he can continue his open-source work

