Compiling Pascal with LLVM: Part 4
Compilation
It took several hundred lines of code, but we're finally there! Today our compiler will come to life and finally compute something for us.
A working prototype
Let's start small and build a compiler that works with something super dumb:
program add;
begin
1 + 1;
end.
We'll start with that, and work our way up. Here's our compiler class:
from llvmlite import ir
class Compiler(Visitor):
def __init__(self, ts: TypeSystem):
self.module = ir.Module()
self._builders = []
self._ts = ts
self._references = ts.references
# for allocated variables
self._allocas = {}
# for functions and strings deduplication
self._function_names = {}
self._string_idx = 0
As always, let's visit
the program first:
def _program(self, node: Program):
main = ir.Function(self.module, ir.FunctionType(ir.VoidType(), ()), '.main')
with self._enter(main):
self.visit_sequence(node.body)
self.builder.ret_void()
Here _enter
is more or less the same as in our type system:
@contextmanager
def _enter(self, func):
self._builders.append(ir.IRBuilder(func.append_basic_block()))
yield
self._builders.pop()
and self.builder
is just a property:
@property
def builder(self) -> ir.IRBuilder:
return self._builders[-1]
In LLVM each function wants a separate IRBuilder; think of it as storage for all the statements in the function's body.
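A tiny, llvmlite-free sketch of this stack-of-builders pattern (all names here are made up for illustration) shows why nested _enter calls work:

```python
from contextlib import contextmanager

class BuilderStack:
    """Stand-in for Compiler._builders: one active builder per function."""

    def __init__(self):
        self._stack = []

    @contextmanager
    def enter(self, builder):
        # in the real compiler this would be ir.IRBuilder(func.append_basic_block())
        self._stack.append(builder)
        try:
            yield
        finally:
            self._stack.pop()

    @property
    def current(self):
        return self._stack[-1]

stack = BuilderStack()
with stack.enter('main'):
    with stack.enter('helper'):
        assert stack.current == 'helper'  # the inner function writes here
    assert stack.current == 'main'        # back to the outer function
```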
Even though Pascal has no main function like C does, LLVM can't just run stuff "in the global scope", so we have to create a .main function here. I chose this name because it's not a valid function name in Pascal, so we won't have name collisions.
For now this function is pretty simple: it has no arguments and no variables, it just goes through the statements and returns Void at the end.
Expression statement
As I quickly mentioned in the previous post, we'll expect lvalues to return a pointer. This will be the signature of the methods that deal with values:
def _my_node(self, node, lvalue: bool):
# ...
Now we don't need the "expected
" argument, because all the information regarding types is already present.
So, for expression statements the code is fairly simple:
def _expression_statement(self, node: ExpressionStatement):
self.visit(node.value, lvalue=False)
Const
Let's keep digging through the simple stuff first. Here's how we evaluate integer and floating constants:
def _const(self, node: Const, lvalue: bool):
value = node.value
match node.type:
case types.SignedInt(_) | types.Floating(_) as kind:
return ir.Constant(resolve(kind), value)
# ... more cases later
raise ValueError(value)
Here resolve
is the bridge between our own type system and the one from LLVM:
def resolve(kind):
match kind:
case types.Void:
return ir.VoidType()
case types.Char:
return ir.IntType(8)
case types.SignedInt(bits):
return ir.IntType(bits)
case types.Floating(64):
return ir.DoubleType()
case types.Reference(kind) | types.Pointer(kind) | types.DynamicArray(kind):
return ir.PointerType(resolve(kind))
case types.StaticArray(dims, kind):
size = reduce(mul, [b - a for a, b in dims], 1)
return ir.ArrayType(resolve(kind), size)
case types.Record(fields):
return ir.LiteralStructType([resolve(field.type) for field in fields])
raise ValueError(kind)
As we can see, LLVM's type system is much simpler: arrays are only 1-dimensional and always start at 0, and there are no references or dynamic arrays, just pointers.
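The flattening of static arrays is just a product over the axis sizes. A small sketch (flat_size is a hypothetical helper, assuming the half-open (start, stop) bounds that resolve seems to use):

```python
from functools import reduce
from operator import mul

def flat_size(dims):
    """Total element count of a Pascal array flattened to a 1D LLVM array.

    dims is a list of (start, stop) bounds, one per axis, stored half-open
    (so Pascal's 1..3 would be (1, 4)); each axis contributes stop - start
    elements, and the flat array holds their product.
    """
    return reduce(mul, [stop - start for start, stop in dims], 1)

# array[1..3, 0..4] of integer would become a flat [15 x i32] in LLVM
assert flat_size([(1, 4), (0, 5)]) == 15
```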
Binary
Once again, binary operators are a headache:
def _binary(self, node: Binary, lvalue: bool):
left = self.visit(node.left, lvalue)
right = self.visit(node.right, lvalue)
kind = self._type(node.left)
right_kind = self._type(node.right)
assert kind == right_kind, (kind, right_kind)
match kind:
case types.SignedInt(_):
if node.op in COMPARISON:
return self.builder.icmp_signed(COMPARISON[node.op], left, right)
return {
'+': self.builder.add,
'-': self.builder.sub,
'*': self.builder.mul,
'/': self.builder.sdiv,
}[node.op](left, right)
case types.Floating(_):
if node.op in COMPARISON:
return self.builder.fcmp_ordered(COMPARISON[node.op], left, right)
return {
'+': self.builder.fadd,
'-': self.builder.fsub,
'*': self.builder.fmul,
'/': self.builder.fdiv,
}[node.op](left, right)
case types.Boolean:
return {
'and': self.builder.and_,
'or': self.builder.or_,
}[node.op](left, right)
case x:
raise TypeError(x)
We visit the left and right operands and make sure they have the same type, which must be true after type casting. Here's how we get the type of a node:
def _type(self, node):
if node in self._ts.casting:
return self._ts.casting[node]
return self._ts.types[node]
Quick reminder: self._ts is an instance of the TypeSystem class, which already analyzed our program.
Then, based on the type, we choose from the multitude of LLVM's operators. Comparison operators get special treatment, because of how they are spelled in Pascal:
COMPARISON = {
'<': '<',
'<=': '<=',
'>': '>',
'>=': '>=',
'=': '==',
'<>': '!=',
}
We need this mapping to convert operators like Pascal's <> to their modern, globally accepted spellings such as != .
Running the code
Now we can finally run our stupid program! We'll need to set a few things up first:
import ctypes
import llvmlite.binding as llvm
from pascal_llvm.compiler import Compiler
from pascal_llvm.parser import Parser
from pascal_llvm.tokenizer import tokenize
from pascal_llvm.type_system import TypeSystem
source = '''
program add;
begin
1 + 1;
end.
'''
# scan and parse
tokens = tokenize(source)
parser = Parser(tokens)
program = parser._program()
# add types
ts = TypeSystem()
ts.visit(program)
# compile
compiler = Compiler(ts)
compiler.visit(program)
module = compiler.module
# translate
module = llvm.parse_assembly(str(module))
module.verify()
# init
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
machine = target.create_target_machine()
engine = llvm.create_mcjit_compiler(llvm.parse_assembly(""), machine)
# load the code
engine.add_module(module)
engine.finalize_object()
engine.run_static_constructors()
# get the ".main" function pointer
main = ctypes.CFUNCTYPE(None)(engine.get_function_address('.main'))
# call it
main()
That's a lot of glue code. But it's reusable! From now on you can simply
change the source
and play around with the compiler as it gets smarter.
At this point the program does basically nothing, but we can print LLVM's IR and see what it thinks of our code. Just call print(module):
; ModuleID = ""
target triple = "unknown-unknown-unknown"
target datalayout = ""
define void @".main"()
{
.2:
%".3" = add i8 1, 1
ret void
}
Even if you're not familiar with the syntax, it's pretty clear what's happening here - a function definition that adds two i8 (char) constants, stores the result in a temporary and returns nothing. Just as we intended!
If you expected LLVM to optimize out this useless add
operation, you're totally right. We'll get to that later.
Variables
Our next milestone is adding variables and tons of code that works with them. First of all, let's define them:
def _program(self, node: Program):
for definitions in node.variables:
for name in definitions.names:
var = ir.GlobalVariable(self.module, resolve(definitions.type), name=name.normalized)
var.linkage = 'private'
self._allocas[name] = var
main = ir.Function(self.module, ir.FunctionType(ir.VoidType(), ()), '.main')
with self._enter(main):
self.visit_sequence(node.body)
self.builder.ret_void()
Nothing special: we create a global variable of a given type and store it in our _allocas dict. I'm not very familiar with the use cases for different linkage types, but here "private" is the way to go - this is basically a global variable private to this module.
Name
Now that we can define variables, let's learn to get their values and pointers.
Remember how in the previous post we resolved each Name
node, and found out what it references? It's time to use this
info:
def _name(self, node: Name, lvalue: bool):
target = self._references[node]
ptr = self._allocas[target]
if lvalue:
if isinstance(self._type(target), types.Reference):
ptr = self.builder.load(ptr)
return ptr
return self.builder.load(ptr)
We get the right pointer first, then dereference it if it's an rvalue; otherwise we just return the pointer. A special case is references, which are represented as pointers: when we write
procedure f(var x: integer);
begin
x := 1;
end;
we pretend x is an integer, while under the hood it's a pointer. So, in _allocas we'll store a pointer to a pointer, which is why we need an additional dereferencing step.
Dereference
Speaking of dereferencing:
def _dereference(self, node: Dereference, lvalue: bool):
# always expect a pointer, so lvalue=True
ptr = self.builder.load(self.visit(node.target, lvalue=True))
if lvalue:
return ptr
return self.builder.load(ptr)
More or less the same stuff.
Unary
While we're at it, the opposite operation - taking the address - is another unary operator:
def _unary(self, node: Unary, lvalue: bool):
# getting the address is a special case
if node.op == '@':
# just get the name's address
return self.visit(node.value, lvalue=True)
value = self.visit(node.value, lvalue)
match node.op:
case '-':
return self.builder.neg(value)
case 'not':
return self.builder.not_(value)
case x:
raise ValueError(x, node)
@
is, once again, a special case.
GetField
LLVM's notion of structs is pretty simple, and is basically indistinguishable from tuples.
record
count: integer;
percentage: real;
end;
This record will have the type (i32, double)
. This means that we can't access fields by name, and must use integer
indices instead:
def _get_field(self, node: GetField, lvalue: bool):
ptr = self.visit(node.target, lvalue=True)
kind = self._type(node.target)
if isinstance(kind, types.Reference):
kind = kind.type
idx, = [i for i, field in enumerate(kind.fields) if field.name == node.name]
ptr = self.builder.gep(
ptr, [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), idx)]
)
if lvalue:
return ptr
return self.builder.load(ptr)
We then use the gep (get element pointer) command to compute a pointer to the right index in the "tuple". I have no idea why we must pass a list with ir.Constant(ir.IntType(32), 0) as the first element, but all the examples I found do it this way. If you can enlighten me, please leave a comment at the bottom of the page.
GetItem
As we saw earlier, in the LLVM world there are only 1D arrays that start at 0. That's OK - we just need to figure out a way to flatten Pascal's arrays. There are many ways to do that; we'll use strides here, in more or less the same fashion as numpy does:
def _get_item(self, node: GetItem, lvalue: bool):
# we always want a pointer from the parent
ptr = self.visit(node.target, lvalue=True)
stride = 1
dims = self._type(node.target).dims
idx = ir.Constant(ir.IntType(32), 0)
for (start, stop), arg in reversed(list(zip(dims, node.args, strict=True))):
local = self.visit(arg, lvalue=False)
# upcast to i32
local = self._cast(local, self._type(arg), types.Integer, False)
# extract the origin
local = self.builder.sub(local, ir.Constant(ir.IntType(32), start))
# multiply by stride
local = self.builder.mul(local, ir.Constant(ir.IntType(32), stride))
# add to index
idx = self.builder.add(idx, local)
stride *= stop - start
ptr = self.builder.gep(ptr, [ir.Constant(ir.IntType(32), 0), idx])
if lvalue:
return ptr
return self.builder.load(ptr)
That's a lot of code, but basically this is what it does:
llvm_index = stride1 * (pascal_index1 - index_start1) + stride2 * (pascal_index2 - index_start2) + ...
This StackOverflow question has a great visualization of what's going on.
Assignment
The final node that will allow us to modify the variables we just created is the assignment:
def _assignment(self, node: Assignment):
ptr = self.visit(node.target, lvalue=True)
value = self.visit(node.value, lvalue=False)
if isinstance(self._type(node.value), types.Reference):
value = self.builder.load(value)
self.builder.store(value, ptr)
Also pretty straightforward: get the value from the right, get the address from the left, store one in the other.
Statements
There aren't a lot of statements in Pascal, so let's dig through them quickly:
If
LLVM reasons in blocks. A block is just a sequence of simple operations: no branching, no control flow or jumps. We just start with the first operation and finish with the last one.
If we need to create a branch, though, we just do it between blocks:
def _if(self, node: If):
condition = self.visit(node.condition, lvalue=False)
then_block = self.builder.append_basic_block()
else_block = self.builder.append_basic_block()
merged_block = self.builder.append_basic_block()
self.builder.cbranch(condition, then_block, else_block)
# then
self.builder.position_at_end(then_block)
self.visit_sequence(node.then_)
self.builder.branch(merged_block)
# else
self.builder.position_at_end(else_block)
self.visit_sequence(node.else_)
self.builder.branch(merged_block)
# phi
self.builder.position_at_end(merged_block)
So here's what's going on. We create 3 blocks: two for if's branches and one for the final merged block. First we create a conditional branch, which jumps to one of the two blocks; then we compile each branch. Note that at the beginning of each branch we enter the corresponding block with position_at_end, and at the end we create an unconditional branch (basically a jump) to the final block.
That's more or less how all control flow will be implemented.
While
While
is, in a way, even simpler than If
:
def _while(self, node: While):
check_block = self.builder.append_basic_block('check')
loop_block = self.builder.append_basic_block('for')
end_block = self.builder.append_basic_block('for-end')
self.builder.branch(check_block)
# check
self.builder.position_at_end(check_block)
condition = self.visit(node.condition, False)
self.builder.cbranch(condition, loop_block, end_block)
# loop
self.builder.position_at_end(loop_block)
self.visit_sequence(node.body)
self.builder.branch(check_block)
# exit
self.builder.position_at_end(end_block)
At each iteration we check the condition and make a conditional jump. Inside the loop we just compile the body and unconditionally jump back to the first block. This way we create a loop. Easy!
For
As always, For
is just a combination of While
and some increments.
It makes you think: maybe we should have desugared this node in the previous post...
def _for(self, node: For):
name = node.name
start = self.visit(node.start, lvalue=False)
stop = self.visit(node.stop, lvalue=False)
self._assign(name, start)
check_block = self.builder.append_basic_block('check')
loop_block = self.builder.append_basic_block('for')
end_block = self.builder.append_basic_block('for-end')
self.builder.branch(check_block)
# check
self.builder.position_at_end(check_block)
counter = self.visit(name, lvalue=False)
condition = self.builder.icmp_signed('<=', counter, stop, 'for-condition')
self.builder.cbranch(condition, loop_block, end_block)
# loop
self.builder.position_at_end(loop_block)
self.visit_sequence(node.body)
# update
increment = self.builder.add(counter, ir.Constant(resolve(self._type(name)), 1), 'increment')
self._assign(name, increment)
self.builder.branch(check_block)
# exit
self.builder.position_at_end(end_block)
I guess only two pieces are interesting here:
start = self.visit(node.start, lvalue=False)
self._assign(name, start)
# and
increment = self.builder.add(counter, ir.Constant(resolve(self._type(name)), 1), 'increment')
self._assign(name, increment)
_assign
is almost identical to _assignment
:
def _assign(self, name: Name, value):
target = self._references[name]
ptr = self._allocas[target]
if isinstance(self._type(target), types.Reference):
ptr = self.builder.load(ptr)
self.builder.store(value, ptr)
As we can see, we're writing to the same variable twice, but LLVM is cool with that.
Functions and calls
We're almost there!
Function definitions
We'll define all the functions after defining the global variables but before the main body:
def _program(self, node: Program):
for definitions in node.variables:
for name in definitions.names:
var = ir.GlobalVariable(self.module, resolve(definitions.type), name=name.normalized)
var.linkage = 'private'
self._allocas[name] = var
# --- new stuff starts here ---
for func in node.functions:
ir.Function(
self.module, ir.FunctionType(resolve(func.return_type), [resolve(arg.type) for arg in func.args]),
self._deduplicate(func),
)
self.visit_sequence(node.functions)
main = ir.Function(self.module, ir.FunctionType(ir.VoidType(), ()), '.main')
with self._enter(main):
self.visit_sequence(node.body)
self.builder.ret_void()
Just like before, we first define all the functions, and only after that we visit each of their bodies.
Here _deduplicate helps us... uh... deduplicate the function names. Remember that we might have several overloaded functions; LLVM can't digest that, so we need to help it a bit:
def _deduplicate(self, node: Function):
if node not in self._function_names:
self._function_names[node] = f'function.{len(self._function_names)}.{node.name.name}'
return self._function_names[node]
We create a new unique name for each function. It's a pretty dumb strategy, I admit, but it's dead simple (and that's a good thing!). If I were writing a real compiler, I would encode the argument and return types in the name instead; that would help a lot with debugging.
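The scheme is easy to try out in isolation. A sketch with a closure standing in for the compiler's state (all names here are made up; the real code keys on Function nodes, not strings):

```python
def make_deduplicator():
    """Unique LLVM-level names for possibly overloaded Pascal functions,
    using the same counter-based scheme as Compiler._deduplicate."""
    names = {}

    def deduplicate(func_id, pascal_name):
        # first time we see this function, mint a fresh name;
        # every later lookup returns the same one
        if func_id not in names:
            names[func_id] = f'function.{len(names)}.{pascal_name}'
        return names[func_id]

    return deduplicate

dedup = make_deduplicator()
# two overloads of `max` get distinct symbols, and lookups are stable
assert dedup('max#int', 'max') == 'function.0.max'
assert dedup('max#real', 'max') == 'function.1.max'
assert dedup('max#int', 'max') == 'function.0.max'
```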
Function
Now to the function itself. It's more or less the same thing as with the Program
:
def _function(self, node: Function):
ret = node.name
func = self.module.get_global(self._deduplicate(node))
with self._enter(func):
if node.return_type != types.Void:
self._allocate(ret, resolve(node.return_type))
for arg, param in zip(func.args, node.args, strict=True):
name = param.name
arg.name = name.normalized
self._allocate(name, resolve(param.type), arg)
for definitions in node.variables:
for name in definitions.names:
self._allocate(name, resolve(definitions.type))
self.visit_sequence(node.body)
if node.return_type != types.Void:
self.builder.ret(self.builder.load(self._allocas[ret]))
else:
self.builder.ret_void()
- Get the function
- Enter the scope
- Define the variables, the arguments and the return value, if any
- Visit the body
- Return
The only difference is that we're not defining functions here, because Pascal doesn't support closures.
We define local variables like so:
def _allocate(self, name: Name, kind: ir.Type, initial=None):
self._allocas[name] = self.builder.alloca(kind, name=name.normalized)
if initial is not None:
self.builder.store(initial, self._allocas[name])
alloca
is LLVM's way to allocate memory in the current function's stack frame.
Call
We can define functions. Time to learn how to call them:
def _call(self, node: Call, lvalue: bool):
magic = MAGIC_FUNCTIONS.get(node.target.normalized)
if magic is not None:
return magic.evaluate(node.args, list(map(self._type, node.args)), self)
target = self._references[node.target]
func = self.module.get_global(self._deduplicate(target))
signature = target.signature
args = []
for arg, kind in zip(node.args, signature.args, strict=True):
if isinstance(kind, types.Reference):
value = self._allocas[self._references[arg]]
else:
value = self.visit(arg, lvalue=False)
args.append(value)
return self.builder.call(func, args)
If we ignore the magical part, this should be straightforward:
- Find the function by its name
- Compute the arguments
- Call the function with these arguments
We only have a small hiccup with (2): if it's a mutable argument, we know for sure it's a variable (we checked for that in the previous post), so we can get its address directly from _allocas.
Now what about those first three lines? They deserve a separate section!
Even more magic
Previously we learned how to validate the arguments of magic functions. Now let's figure out how to evaluate them. We'll extend the interface with another method:
class MagicFunction(ABC):
@classmethod
@abstractmethod
def validate(cls, args, visit) -> types.DataType:
pass
@classmethod
@abstractmethod
def evaluate(cls, args, kinds, compiler):
pass
and this is how WriteLn
would implement it:
class WriteLn(MagicFunction):
@classmethod
def validate(cls, args, visit) -> types.DataType:
for arg in args:
visit(arg, None, False)
return types.Void
@classmethod
def evaluate(cls, args, kinds, compiler):
ptr = compiler.string_pointer(format_io(kinds) + b'\n\00')
return compiler.builder.call(compiler.module.get_global('printf'), [ptr, *compiler.visit_sequence(args, False)])
We're basically preparing arguments for printf - a function from the C world. It expects a format spec that depends on argument types:
from jboc import composed
@composed(b' '.join)
def format_io(args):
for arg in args:
match arg:
case types.SignedInt(_):
yield b'%d'
case types.Floating(_):
yield b'%f'
case types.Char:
yield b'%c'
case types.StaticArray(dims, types.Char) if len(dims) == 1:
yield b'%s'
case types.DynamicArray(types.Char):
yield b'%s'
case kind:
raise TypeError(kind)
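If you don't have jboc around, composed is easy to approximate: as far as I can tell, it just feeds everything the generator yields to the given function. A rough stand-in, together with a stripped-down version of the format mapping over plain string tags instead of type objects:

```python
from functools import wraps

def composed(combine):
    """Rough equivalent of jboc's `composed`: call the generator,
    then pass everything it yields to `combine`."""
    def decorator(gen):
        @wraps(gen)
        def wrapper(*args, **kwargs):
            return combine(gen(*args, **kwargs))
        return wrapper
    return decorator

@composed(b' '.join)
def format_demo(tags):
    # toy version of format_io: map type tags to printf specs
    specs = {'int': b'%d', 'float': b'%f', 'char': b'%c', 'string': b'%s'}
    for tag in tags:
        yield specs[tag]

assert format_demo(['int', 'string']) == b'%d %s'
```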
We then just find printf
and call it. But to find it we first need to define it. We'll store all the definitions for
external functions here:
FFI = {
'printf': ir.FunctionType(ir.IntType(32), [ir.IntType(8).as_pointer()], var_arg=True),
'scanf': ir.FunctionType(ir.IntType(32), [ir.IntType(8).as_pointer()], var_arg=True),
'getchar': ir.FunctionType(ir.IntType(8), []),
'rand': ir.FunctionType(ir.IntType(32), []),
'srand': ir.FunctionType(ir.VoidType(), [ir.IntType(32)]),
'time': ir.FunctionType(ir.IntType(32), [ir.IntType(32)]),
}
FFI stands for foreign function interface
I defined a bunch of them just in case. Note that printf
and scanf
are variadic, just like writeln
and readln
in Pascal.
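You can poke at variadic C functions from Python directly with ctypes, the same mechanism our glue code uses to call .main. A sketch (assumes a Linux/macOS libc; snprintf is variadic just like printf, but writes into a buffer instead of stdout, which makes it easier to check):

```python
import ctypes

# CDLL(None) exposes libc symbols on Linux/macOS; on Windows use msvcrt instead
libc = ctypes.CDLL(None)
buf = ctypes.create_string_buffer(64)
# trailing arguments after the format string are the variadic part
libc.snprintf(buf, 64, b'%d %s', 42, b'hi')
assert buf.value == b'42 hi'
```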
Finally, we'll add these definitions in Compiler
's constructor:
def __init__(self, ts: TypeSystem):
# ... other stuff
for name, kind in FFI.items():
ir.Function(self.module, kind, name)
Wait. What's with that string_pointer function? We need to pass a pointer to the spec string to printf. Here's the solution I came up with: we'll make strings compile-time global constants, like so:
def string_pointer(self, value: bytes):
kind = ir.ArrayType(ir.IntType(8), len(value))
global_string = ir.GlobalVariable(self.module, kind, name=f'string.{self._string_idx}.global')
global_string.global_constant = True
global_string.initializer = ir.Constant(kind, [ir.Constant(ir.IntType(8), x) for x in value])
self._string_idx += 1
return self.builder.gep(
global_string, [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)]
)
We create a global variable with a unique name (_string_idx makes sure of that), initialize it with our string, then take its pointer and return it.
With this code in place, we can also add another branch to _const
:
def _const(self, node: Const, lvalue: bool):
value = node.value
match node.type:
case types.StaticArray(dims, types.Char) if len(dims) == 1:
return self.string_pointer(value)
# ... rest of the cases
Read
I'll show you a few more implementations real quick:
class Read(MagicFunction):
@classmethod
def validate(cls, args, visit) -> types.DataType:
if not args:
raise WrongType
for arg in args:
visit(arg, None, True)
return types.Void
@classmethod
def evaluate(cls, args, kinds, compiler):
ptr = compiler.string_pointer(format_io(kinds) + b'\00')
return compiler.builder.call(compiler.module.get_global('scanf'), [ptr, *compiler.visit_sequence(args, True)])
For Read
we simply call into scanf
instead of printf
.
ReadLn
Now ReadLn
is a totally different beast:
class ReadLn(MagicFunction):
@classmethod
def validate(cls, args, visit) -> types.DataType:
for arg in args:
visit(arg, None, True)
return types.Void
@classmethod
def evaluate(cls, args, kinds, compiler):
builder = compiler.builder
ptr = compiler.string_pointer(format_io(kinds) + b'\00')
builder.call(compiler.module.get_global('scanf'), [ptr, *compiler.visit_sequence(args, True)])
# ignore the rest of the line: while (getchar() != '\n') {} // ord('\n') == 10
check_block = builder.append_basic_block('check')
loop_block = builder.append_basic_block('loop')
end_block = builder.append_basic_block('end')
builder.branch(check_block)
# check
builder.position_at_end(check_block)
condition = builder.icmp_signed(
'!=', builder.call(compiler.module.get_global('getchar'), ()), ir.Constant(ir.IntType(8), 10)
)
builder.cbranch(condition, loop_block, end_block)
# loop
builder.position_at_end(loop_block)
builder.branch(check_block)
# exit
builder.position_at_end(end_block)
We first read several values, then we must skip the rest until we hit the line end. All the wizardry here is just translating this C code:
scanf(...);
while (getchar() != '\n') {}
Final preparations
Let's deal quickly with the boring stuff.
Type casting
Just like before, all type casting happens in after_visit:
def _cast(self, value, src: types.DataType, dst: types.DataType, lvalue: bool):
# references are just fancy pointers
if isinstance(src, types.Reference) and not lvalue:
src = src.type
value = self.builder.load(value)
if src == dst:
return value
match src, dst:
case types.SignedInt(_), types.Floating(_):
return self.builder.sitofp(value, resolve(dst))
case types.SignedInt(_), types.SignedInt(_):
return self.builder.sext(value, resolve(dst))
case types.StaticArray(_), types.DynamicArray(_):
return value
# ... more cases here perhaps
raise NotImplementedError(value, src, dst)
def after_visit(self, node, value, lvalue=None):
if node in self._ts.casting:
assert lvalue is not None
return self._cast(value, self._ts.types[node], self._ts.casting[node], lvalue)
return value
Desugaring
To write this one we'll need to extend the Visitor
interface:
class Visitor:
def visit(self, node, *args, **kwargs):
node = self.before_visit(node, *args, **kwargs)
value = getattr(self, f'_{snake_case(type(node).__name__)}')(node, *args, **kwargs)
value = self.after_visit(node, value, *args, **kwargs)
return value
def before_visit(self, node, *args, **kwargs):
return node
# ... other methods
So, before visiting the node, we can change it and force the compiler to visit something else - perfect for desugaring.
And that's how the Compiler
will implement it:
def before_visit(self, node, *args, **kwargs):
return self._ts.desugar.get(node, node)
d.get(x, x)
is just a quicker way to write d[x] if x in d else x
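One loose end: the snake_case helper used by visit isn't shown in this post. A guess at a minimal version that matches the node names we've seen:

```python
import re

def snake_case(name):
    """Turn class names like `GetField` into method suffixes like
    `get_field`, as the visit dispatcher expects (my reconstruction
    of the helper, not the series' actual code)."""
    # insert an underscore before every capital letter except the first
    return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()

assert snake_case('GetField') == 'get_field'
assert snake_case('If') == 'if'
```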
Magic registry
Last time we stored all our magic functions in a dict:
MAGIC_FUNCTIONS = {
'writeln': WriteLn,
}
For 2-3 functions this is fine, but it quickly becomes tedious. Here's a cool way to do this automatically:
class MagicFunction(ABC):
# ... other methods
def __init_subclass__(cls, **kwargs):
name = cls.__name__.lower()
assert name not in MAGIC_FUNCTIONS
MAGIC_FUNCTIONS[name] = cls
__init_subclass__
is a hook that triggers right after a new subclass is created - you get the idea.
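To see the hook in action, here's a scratch registry independent of the compiler (Magic and REGISTRY are throwaway names for this demo):

```python
from abc import ABC

REGISTRY = {}

class Magic(ABC):
    # stand-in for MagicFunction: every subclass registers itself
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        name = cls.__name__.lower()
        assert name not in REGISTRY, name  # catch accidental duplicates
        REGISTRY[name] = cls

class WriteLn(Magic):
    pass

class ReadLn(Magic):
    pass

# merely defining the classes filled the registry
assert REGISTRY == {'writeln': WriteLn, 'readln': ReadLn}
```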
Optimization
At this point you should be able to finally print "Hello world!":
program main;
begin
writeln('Hello World!');
end.
But there's more we can do! An obvious strength of LLVM is optimization - it ships a lot of different optimization passes. Today we'll use just a few that seem self-explanatory:
# ....
# translate
module = llvm.parse_assembly(str(module))
module.verify()
# optimize
pm_builder = llvm.PassManagerBuilder()
pm = llvm.ModulePassManager()
pm_builder.populate(pm)
# add passes
pm.add_constant_merge_pass()
pm.add_instruction_combining_pass()
pm.add_reassociate_expressions_pass()
pm.add_gvn_pass()
pm.add_cfg_simplification_pass()
pm.add_loop_simplification_pass()
pm.run(module)
# run
llvm.initialize()
# ...
The pass manager pm
will do all the heavy lifting for us:
- merge constants
- reassociate expressions:
(1 + 2) * (2 + 1) -> (1 + 2) * (1 + 2)
- simplify expressions:
(1 + 2) * (1 + 2) -> (1 + 2) ** 2
- simplify branching: control flow and loops
- remove redundant instructions
You can print the module
before and after pm.run(module)
and see the difference. Pretty cool, right?
~
If you're reading this, thanks for sticking around! I hope you enjoyed reading this series of posts as much as I enjoyed implementing all this stuff.
There are a ton of ways one could improve this implementation, and it's by no means complete, but I feel like it still gives a good perspective on what the key components are. Let me know in the comments if there are some interesting aspects that I didn't cover!