Compiling Pascal with LLVM: Part 3
Typing
Today we're going to write a bit less code than last time. But don't worry, I'll compensate for that with a healthy scoop of (hopefully) new concepts!
Adding types
First, let's fix all the crutches we left in our parser. And we'll start this by adding types.
from dataclasses import dataclass
hashable = dataclass(unsafe_hash=True, repr=True)
class DataType:
    """ Base class for all data types """
class VoidType(DataType):
    pass
class BooleanType(DataType):
    pass
class CharType(DataType):
    pass
@hashable
class SignedInt(DataType):
    bits: int
@hashable
class Floating(DataType):
    bits: int
@hashable
class Pointer(DataType):
    type: DataType
@hashable
class Reference(DataType):
    type: DataType
@hashable
class StaticArray(DataType):
    dims: tuple[tuple[int, int]]
    type: DataType
@hashable
class DynamicArray(DataType):
    type: DataType
@hashable
class Field:
    name: str
    type: DataType
@hashable
class Record(DataType):
    fields: tuple[Field]
@hashable
class Signature:
    args: tuple[DataType]
    return_type: DataType
@hashable
class Function(DataType):
    signatures: tuple[Signature]
So, we're representing types with... Python classes. Each type is an instance of DataType. We have some basic stuff like bool, void, ints and floats with different numbers of bits, and two kinds of pointers: pointers and references. There isn't much difference between them at runtime, it's more about language semantics. Then we have more complex types like static (with dimensions known at compile time) or dynamic arrays and records (structs, if you're coming from C).
Finally, we introduce a special type for overloaded functions. Each of them is just a collection of Signatures: the types of its arguments, and the return type.
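Why bother with unsafe_hash=True? As far as I can tell, the point is that two separately built types with the same fields compare equal and hash the same, so they can be compared with == and used as dict keys later on. Here's a minimal sketch with a trimmed-down copy of the classes above (the sizes dict is made up for the demo):

```python
from dataclasses import dataclass

hashable = dataclass(unsafe_hash=True, repr=True)


class DataType:
    """ Base class for all data types """


@hashable
class SignedInt(DataType):
    bits: int


@hashable
class Pointer(DataType):
    type: DataType


# structurally equal types are interchangeable
a = Pointer(SignedInt(32))
b = Pointer(SignedInt(32))
same = a == b and hash(a) == hash(b)

# so they can be used as dictionary keys (hypothetical sizes in bytes)
sizes = {SignedInt(8): 1, SignedInt(32): 4}
integer_size = sizes[SignedInt(32)]
```

Note that the classes without fields (VoidType and friends) aren't dataclasses, so they compare by identity. That's fine as long as we only ever use the single instances created below.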
Also let's create some useful types:
Ints = Byte, Integer = SignedInt(8), SignedInt(32)
Floats = Real, = Floating(64),
Void, Boolean, Char = VoidType(), BooleanType(), CharType()
TYPE_NAMES = {
'integer': Integer,
'real': Real,
'char': Char,
'byte': Byte,
'boolean': Boolean,
}
We'll use TYPE_NAMES in the parser later.
Now let's fix the parser, and add more types to it.
First we'll fix the nodes, by changing all type: str to type: DataType.
# I keep all the types in `types.py`
from . import types
@unique
class Const:
    value: Any
    type: types.DataType
@unique
class Definitions:
    names: tuple[Name]
    type: types.DataType
@unique
class ArgDefinition:
    name: Name
    type: types.DataType
@unique
class Function:
    name: Name
    args: tuple[ArgDefinition]
    variables: tuple[Definitions]
    body: tuple[Any]
    return_type: types.DataType
Constants
def _primary(self):
    match self.peek().type:
        case TokenType.NUMBER:
            body = self.consume().string
            if '.' not in body:
                value = int(body)
                for kind in types.Ints:
                    if value.bit_length() < kind.bits:
                        return Const(value, kind)
            return Const(float(body), types.Real)
        case TokenType.STRING:
            value = self.consume().string
            if not value.startswith("'"):
                raise ParseError('Strings must start and end with apostrophes')
            value = eval(value).encode() + b'\00'
            return Const(value, types.StaticArray(((0, len(value)),), types.Char))
        # ... other cases are unchanged
Pretty straightforward. Strings are now just arrays of chars, floats and ints got a type instead of 'integer' and 'real'. Also, we're being a bit smarter here, and trying to pack integers in the smallest number of bits possible. So 1 will be of type Byte, while 1000 is an Integer. Pascal has automatic type upcasting, so this is ok.
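To see where the boundary between Byte and Integer lies, here's the same packing rule pulled out into a standalone function. This is only a sketch: smallest_int_bits and INT_BITS are names I made up for the demo, the parser itself loops over types.Ints.

```python
# bit widths of the integer types, smallest first (Byte, Integer)
INT_BITS = (8, 32)


def smallest_int_bits(value: int):
    """Return the width of the first signed type that fits `value`, or None."""
    for bits in INT_BITS:
        # one bit is reserved for the sign, hence the strict comparison
        if value.bit_length() < bits:
            return bits
    return None


widths = {value: smallest_int_bits(value) for value in (1, 127, 128, 1000)}
```

So 127 still fits in a Byte, while 128 already needs an Integer. A literal too big for both gets None here; in the parser it falls through to the Real branch.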
Definitions
Now this is the most tedious part. We'll have to add support for a bunch of type definitions. It's pretty straightforward though, so I won't waste your time describing what's going on here.
def _type(self):
    if self.consumed(TokenType.CIRCUMFLEX):
        return types.Pointer(self._type())
    if self.consumed(TokenType.NAME, string='array'):
        if self.consumed(TokenType.LSQB):
            # true array
            dims = [self._array_dims()]
            while self.consumed(TokenType.COMMA):
                dims.append(self._array_dims())
            self.consume(TokenType.RSQB)
            self.consume(TokenType.NAME, string='of')
            internal = self._type()
            return types.StaticArray(tuple(dims), internal)
        # just a pointer
        self.consume(TokenType.NAME, string='of')
        internal = self._type()
        return types.DynamicArray(internal)
    # string is just a special case of an array
    if self.consumed(TokenType.NAME, string='string'):
        if self.consumed(TokenType.LSQB):
            dims = self._array_dims(),
            self.consume(TokenType.RSQB)
            return types.StaticArray(dims, types.Char)
        return types.DynamicArray(types.Char)
    if self.consumed(TokenType.NAME, string='record'):
        fields = []
        while not self.consumed(TokenType.NAME, string='end'):
            definition = self._definition()
            for name in definition.names:
                fields.append(types.Field(name.name, definition.type))
        return types.Record(tuple(fields))
    kind = self.consume(TokenType.NAME).string.lower()
    return types.TYPE_NAMES[kind]
A lot of repetitive stuff, although I managed to move some of it out into these small functions:
def _int(self):
    neg = self.consumed(TokenType.OP, string='-')
    value = int(self.consume(TokenType.NUMBER).string)
    if neg:
        return -value
    return value

def _array_dims(self):
    first = self._int()
    if self.consumed(TokenType.DOT):
        self.consume(TokenType.DOT)
        return first, self._int()
    return 0, first
By the way, should the first part of array dims always be smaller than the second one? Or should we allow stuff like array[10..1] of integer? I would expect the compiler to be smart enough to statically detect this.
Functions
Finally, let's fix that ugly crutch in _prototype:
# replace
if mutable:
    kind = f'reference({kind})'
# by
if mutable:
    kind = types.Reference(kind)
The Visitor pattern
With all this in place, after parsing the code we get an abstract syntax tree or AST. From now on we'll do a lot of tree walking, which can be quite tedious if you do it without the proper tools.
In functional languages such a tool is pattern matching. Yes, Python also has pattern matching syntax, and we even used it already a few times. However, for tree walking I feel like it will get messy very quickly, because we'll be forced to cram all the code into a single function.
To keep things nicely separated we'll use the visitor pattern.
Another book recommendation: the Gang of Four's "Design Patterns".
Because we're using Python, a super dynamic language, I'll show you a handy way to implement the visitor pattern without the need to do type checks or modify the classes we visit:
import re
# credit: https://stackoverflow.com/a/1176023
first_cap = re.compile(r'(.)([A-Z][a-z]+)')
all_cap = re.compile(r'([a-z\d])([A-Z])')
def snake_case(name):
    name = first_cap.sub(r'\1_\2', name)
    return all_cap.sub(r'\1_\2', name).lower()

class Visitor:
    def visit(self, node, *args, **kwargs):
        value = getattr(self, f'_{snake_case(type(node).__name__)}')(node, *args, **kwargs)
        value = self.after_visit(node, value, *args, **kwargs)
        return value
    def visit_sequence(self, nodes, *args, **kwargs):
        return tuple(self.visit(node, *args, **kwargs) for node in nodes)
    def after_visit(self, node, value, *args, **kwargs):
        return value
All the fun is happening inside visit. We're basically doing a kind of dynamic dispatch based on the class name. So, if we call Visitor.visit(MyClass()), inside it will get dispatched to Visitor._my_class(value).
I'm converting the class name from CamelCase to snake_case because of PEP8.
There's also a useful after_visit method, which will come in handy pretty soon. Think of it as a post-visit hook.
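To make the dispatch more tangible, here's a tiny self-contained demo. Number, BinaryAdd and Evaluator are hypothetical names invented for this example (they aren't part of our compiler), and the Visitor is a shortened copy of the one above:

```python
import re

first_cap = re.compile(r'(.)([A-Z][a-z]+)')
all_cap = re.compile(r'([a-z\d])([A-Z])')


def snake_case(name):
    name = first_cap.sub(r'\1_\2', name)
    return all_cap.sub(r'\1_\2', name).lower()


class Visitor:
    def visit(self, node, *args, **kwargs):
        method = getattr(self, f'_{snake_case(type(node).__name__)}')
        value = method(node, *args, **kwargs)
        return self.after_visit(node, value, *args, **kwargs)

    def after_visit(self, node, value, *args, **kwargs):
        return value


# hypothetical nodes, only for this demo
class Number:
    def __init__(self, value):
        self.value = value


class BinaryAdd:
    def __init__(self, left, right):
        self.left, self.right = left, right


class Evaluator(Visitor):
    def __init__(self):
        self.visited = []

    def _number(self, node):
        return node.value

    def _binary_add(self, node):
        return self.visit(node.left) + self.visit(node.right)

    def after_visit(self, node, value):
        # the post-visit hook sees every node together with its result
        self.visited.append(type(node).__name__)
        return value


evaluator = Evaluator()
result = evaluator.visit(BinaryAdd(Number(1), Number(2)))
```

Note that the node classes know nothing about the visitor, and the hook fires for children before their parent, because a parent's method only returns after its children have been visited.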
Static analysis
Now that we're ready for tree walking, let's see what we actually want to do with our AST.
[Example Pascal program; the listing itself is not preserved here. Its annotations:]
- we can shadow global variables
- moreover, we can shadow with a variable of another type
- which variable are we referring to here? local or global?
- this is how we define the return value
- we can overload functions
- referencing and changing a global variable
- assignments can be more complex
- we can call functions that take 0 arguments without parentheses
The little program above has all the concepts we need to catch.
Name tracking: for each Name node, we need to know which variable or function it refers to, which can be
- a simple variable
- one of the overloaded functions that have the same name
- inside functions, we also have a special variable to assign the return value to
Type inference: for each node inside an expression we want to know its type as well as check for type errors and do type casting along the way.
Static dispatch: because there may be overloaded functions, we need to statically determine which function the user is referring to.
Now this is a lot of work, and it may be a good idea to split it into several passes of tree walking. But this will require a bit more code to write, and I want to run my "Hello world" as soon as possible.
So, here's our new visitor:
class TypeSystem(Visitor):
    def __init__(self):
        # the actual types of nodes: Node -> DataType
        self.types = {}
        # what the nodes should be cast to: Node -> DataType
        self.casting = {}
        # what each `Name` node is referring to: Node -> Node
        self.references = {}
Each method for expression nodes will have the following signature:
def _my_node(self, node, expected: DataType | None, lvalue: bool) -> DataType:
    # ...
We will specify which type the node is expected to have, or None, if we don't care. This will come in handy during type checks. The methods will return the actual type of the node, e.g. for 1 + 1 it will probably return Byte.
The last parameter, lvalue, is more interesting. Consider this statement:
r[5].age = 25 + 1;
As we saw earlier, an assignment is basically two expressions delimited by the := token.
Each expression has a value, and it turns out that values come in two colors: lvalues and rvalues.
l and r, you guessed it, stand for left and right respectively. So r[5].age is an lvalue, and 25 + 1 is an rvalue in our example.
The only difference between them, at least for us, will be that we'll expect lvalues to return a pointer. This makes sense, because we need an address in memory in which we'll store the rvalue we just computed. This requirement also covers cases like
1 + 2 := myfunc(3);
The type checker will complain that there's simply no way to compute the pointer to 1 + 2.
Compile-time constants such as 1 or simple expressions like 1 + 2 might be optimized out by the compiler or even stored in a register rather than RAM, and there are no pointers to registers!
Finally, we'll store types and casting information in self.types and self.casting. The after_visit method is the best place to do this:
def after_visit(self, node, kind, expected=None, lvalue=None):
    # node types
    self.types[node] = kind
    # optionally add type casting, if needed
    if expected is not None:
        if not self.can_cast(kind, expected):
            raise WrongType(kind, expected)
        # if no casting needed - just remove it
        if kind == expected:
            self.casting.pop(node, None)
        else:
            self.casting[node] = expected
            kind = expected
    return kind
We'll implement can_cast later. For now let's just assume it knows all the casting rules, e.g. can_cast(Byte, Integer) is True but can_cast(Real, Char) is False.
The scope
We need a place to store all the defined variables: a scope. Keep in mind that functions can shadow global variables, so when we call a function we enter a new scope. Usually scopes are stored in a stack. When we enter a new scope we:
- push an empty scope (usually a dict) on top of the stack
- define new variables by writing to the top scope
- read variables by traversing the stack: start with the top scope and go down until you find the variable with the name you're looking for
- after you're done just pop the scope from the top of the stack
Let's look at an example:
program scope;
var a, b, c: integer;
function f(d: integer): integer;
var c, e: integer;
begin
    f := a + b + c + d + e;
end;
begin
    f(1);
end.
After calling the function f we enter its scope, and our stack looks like this:
Global Local
+------------+ +------------+
| a, b, c, f | -> | c, d, e, f |
+------------+ +------------+
So, in f := a + b + c + d + e, to find a and b we'll have to traverse the stack, because they're not present in the current Local scope.
Note that f is present in both scopes, and it even means different things: in Global it's the function f, in Local it's the special variable we're writing the return value to.
You might ask: "what about recursion?" We want to be able to call the function inside its own body. Yes, we'll get to that shortly, don't worry!
Now let's implement all this behaviour. We'll need methods for entering and leaving the scope, as well as defining variables and functions and referencing them by name.
from contextlib import contextmanager
class TypeSystem(Visitor):
    def __init__(self):
        self._scopes = []
        self._func_return_names = []
        self.types = {}
        self.casting = {}
        self.references = {}
        self.desugar = {}
    @contextmanager
    def _enter(self):
        self._scopes.append({})
        yield
        self._scopes.pop()
So far so good, we're using a list here as a stack, and a small context manager _enter has all the code we need for entering and leaving a scope.
Now that we're in a scope, here's how we'll define new names in it:
def _store(self, name: str, kind: types.DataType, payload):
    assert name not in self._scopes[-1]
    self._scopes[-1][name] = kind, payload
We store a value of type kind by its name in the topmost scope. You'll see in a moment what payload is for.
The symmetric operation is finding a value by its name:
def _resolve(self, name: str):
    for scope in reversed(self._scopes):
        if name in scope:
            return scope[name]
    raise KeyError(name)
The scopes are reversed, because we want to iterate from the list's tail - the top of the stack.
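Here's how these pieces behave together on the shadowing example from above. This is a stripped-down stand-in, not the real TypeSystem: the Scopes class and the string values are made up for the demo, and I added a try/finally so that the scope is popped even if the body raises.

```python
from contextlib import contextmanager


class Scopes:
    """A stripped-down stand-in for the scope handling in TypeSystem."""

    def __init__(self):
        self._scopes = []

    @contextmanager
    def enter(self):
        self._scopes.append({})
        try:
            yield
        finally:
            # leave the scope even if the body raised
            self._scopes.pop()

    def store(self, name, value):
        assert name not in self._scopes[-1]
        self._scopes[-1][name] = value

    def resolve(self, name):
        for scope in reversed(self._scopes):
            if name in scope:
                return scope[name]
        raise KeyError(name)


scopes = Scopes()
seen = {}
with scopes.enter():  # Global
    scopes.store('a', 'global a')
    scopes.store('c', 'global c')
    with scopes.enter():  # Local
        scopes.store('c', 'local c')
        seen['a'] = scopes.resolve('a')  # found one level down
        seen['c'] = scopes.resolve('c')  # shadowed by the local one
    seen['c after'] = scopes.resolve('c')
```

Inside the function c resolves to the local variable, and as soon as we leave the scope the global one is visible again.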
And, finally, we store the information that a Name refers to a given node like so:
def _bind(self, source, destination):
    self.references[source] = destination
The compiler will need this information to quickly find the pointer to the right variable or function.
Program
With all the pieces in place, let's start with the main stuff: the program itself and the functions:
from collections import defaultdict

def _program(self, node: Program):
    with self._enter():
        # vars
        for definitions in node.variables:
            for name in definitions.names:
                self._store_value(name, definitions.type)
        # functions
        functions = defaultdict(list)
        for func in node.functions:
            functions[func.name.normalized].append(func)
        for name, funcs in functions.items():
            funcs = {f.signature: f for f in funcs}
            self._store(name, types.Function(tuple(funcs)), funcs)
        self.visit_sequence(node.functions)
        self.visit_sequence(node.body)
Here _store_value is a small util method to store Name nodes:
def _store_value(self, name: Name, kind: types.DataType):
    self._store(name.normalized, kind, name)
    self.types[name] = kind
This is handy because when we're defining a variable we already know its type, so we can store this info on the spot.
For functions, though, it's not as simple because of overloading - there are several functions with the same name. That's why we use a defaultdict(list) - this is a simple way to split a set of objects into groups, in our case - functions. Next, we _store a single entry for each function, but save the information that will help us differentiate between them in the payload.
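If the grouping trick is new to you, here it is in isolation. The (name, argument types) tuples are made-up stand-ins for real Function nodes, and the strings stand in for real types:

```python
from collections import defaultdict

# (name, argument types) pairs standing in for Function nodes
declared = [
    ('Max', ('integer', 'integer')),
    ('max', ('real', 'real')),
    ('print', ('integer',)),
]

groups = defaultdict(list)
for name, args in declared:
    # group by the lowercased name, so `Max` and `max` end up together
    groups[name.lower()].append(args)

# one scope entry per name; the payload maps each signature to its function
payloads = {
    name: {args: (name, args) for args in overloads}
    for name, overloads in groups.items()
}
```

We end up with one entry per name, and the payload lets us get from a chosen signature back to the concrete function.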
We're not done with functions yet! Besides storing the functions' names in the global scope, we need to visit their bodies and resolve all the local variables. Note that we start visiting the functions only after we've defined all of them. This step makes sure that recursion works as expected, because we can resolve a function's name even if we didn't visit its body yet.
Finally, we simply visit each statement in the program's body.
For completeness, here's the body of visit_sequence:
def visit_sequence(self, nodes, *args, **kwargs):
    return tuple(self.visit(node, *args, **kwargs) for node in nodes)
and Name.normalized is just
class Name:
    name: str
    @property
    def normalized(self):
        return self.name.lower()
which is handy, because Pascal is case-insensitive.
Function
Now that we're done with the hard part, visiting a Function node should look almost identical:
def _function(self, node: Function):
    with self._enter():
        # (return type, name): the same order in which `_name` unpacks the pair
        self._func_return_names.append((node.return_type, node.name))
        self.types[node.name] = node.return_type
        for arg in node.args:
            self._store_value(arg.name, arg.type)
        for definitions in node.variables:
            for name in definitions.names:
                self._store_value(name, definitions.type)
        self.visit_sequence(node.body)
        self._func_return_names.pop()
Most of the code here is about handling this weird "assign to function's name to define the return value" behaviour. Honestly, I don't know how to handle this better, so here we go: we keep a stack of (return_type, function_name) pairs, which we'll use later to resolve lvalues. At the very end we simply pop this pair from the stack.
The rest is pretty straightforward. Define the variables, don't forget about function arguments (which are also a kind of local variables) then visit each statement in the body.
Assignment
This is one of our main nodes. Assignments are the bridge between lvalues and rvalues:
def _assignment(self, node: Assignment):
    kind = self.visit(node.target, expected=None, lvalue=True)
    # no need to cast to reference in this case
    if isinstance(kind, types.Reference):
        kind = kind.type
    self.visit(node.value, expected=kind, lvalue=False)
First, we get the type of the left side. Here expected is None because we don't care which type we're going to store into, we only care what we'll store there. That's why we visit the right side by passing the type constraint that we received from the left side.
Additionally, we unwrap the potential Reference here: writing to a reference is the same as writing to a regular variable, at least from the type system's perspective.
Const
What's the type of a Const node? Simple! We already stored the type while parsing:
def _const(self, node: Const, expected: types.DataType, lvalue: bool):
    if lvalue:
        raise WrongType(node)
    return node.type
Additionally, we make sure here that we're not trying to assign anything to this node, i.e. it's an rvalue.
Name
Here comes the moment of truth: this little method handles all the references to variables and functions.
def _name(self, node: Name, expected: types.DataType, lvalue: bool):
    # assignment to the function's name inside a function is definition of a return value
    if lvalue and self._func_return_names:
        kind, target = self._func_return_names[-1]
        # compare normalized names, because Pascal is case-insensitive
        if kind != types.Void and target.normalized == node.normalized:
            self._bind(node, target)
            return kind
    kind, target = self._resolve(node.normalized)
    if isinstance(kind, types.Function):
        self.desugar[node] = new = Call(node, ())
        return self._call(new, expected, lvalue)
    self._bind(node, target)
    return kind
First we handle our ugly "assign to function name" case. We do this only if
- it's an lvalue
- we're inside a function, i.e. _func_return_names isn't empty
- we're inside a non-Void function (not a procedure), so the return type isn't Void
- the name we're referring to is the same as the function's name
If all these conditions are met, we bind the current node to the function's return value.
Otherwise, we resolve the name and just bind it to the variable we found.
Finally, there's one more case we need to handle. As we saw before, you can call functions with 0 arguments without parentheses. This is 100% legal:
program legal;
begin
writeln;
end.
Looks like in the 70s programmers liked syntactic sugar even more than we do today.
That's why we do another check - if it's a function then it's actually a function call, and we need to replace the current node with a Call(node, ()) - we desugar it and store this info to help the compiler.
Call
def _call(self, node: Call, expected: types.DataType, lvalue: bool):
    if not isinstance(node.target, Name):
        raise WrongType(node)
    # get all the functions with this name
    kind, targets = self._resolve(node.target.normalized)
    if not isinstance(kind, types.Function):
        raise WrongType(kind)
    # choose the right function
    signature = self._dispatch(node.args, kind.signatures, expected)
    self._bind(node.target, targets[signature])
    return signature.return_type
Handling calls is pretty straightforward:
- take the target, which must be a Name, functions aren't first class citizens in Pascal!
- resolve the name and make sure we've found a function
- if it's an overloaded function, choose the right variant based on the signatures (static dispatch)
- bind the Name node to the function we just chose
All the heavy lifting is done in our _dispatch function:
def _dispatch(self, args: Sequence, signatures: Sequence[types.Signature], expected: types.DataType):
    for signature in signatures:
        if len(signature.args) != len(args):
            continue
        if not self.can_cast(signature.return_type, expected):
            continue
        try:
            for arg, kind in zip(args, signature.args, strict=True):
                if isinstance(kind, types.Reference) and not isinstance(arg, Name):
                    raise WrongType('Only variables can be mutable arguments')
                self.visit(arg, expected=kind, lvalue=False)
        except WrongType:
            continue
        return signature
    raise WrongType(args, expected, signatures)
Also pretty simple, just loop over all the signatures we have and try to find a match based on the number of args, their types, and the expected return type of the function. We also check along the way that if an argument is mutable we can only pass a variable to it.
In the end we just fail with a WrongType if nothing was found.
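If you want to play with the idea outside of the visitor, here's a simplified, standalone version of the same first-match search. It's a sketch under loud assumptions: types are plain strings, the WIDENING table is a toy replacement for can_cast, and the arguments are already-known types rather than nodes we need to visit.

```python
from dataclasses import dataclass


class WrongType(Exception):
    pass


@dataclass(frozen=True)
class Signature:
    args: tuple
    return_type: str


# toy casting table, an assumption for this demo only
WIDENING = {('byte', 'integer'), ('byte', 'real'), ('integer', 'real')}


def can_cast(kind, to):
    return to is None or kind == to or (kind, to) in WIDENING


def dispatch(arg_types, signatures, expected=None):
    for signature in signatures:
        if len(signature.args) != len(arg_types):
            continue
        if not can_cast(signature.return_type, expected):
            continue
        if all(can_cast(a, p) for a, p in zip(arg_types, signature.args)):
            # the first signature that fits wins
            return signature
    raise WrongType(arg_types, expected, signatures)


overloads = [
    Signature(('integer', 'integer'), 'integer'),
    Signature(('real', 'real'), 'real'),
]
ints = dispatch(('byte', 'integer'), overloads)
mixed = dispatch(('integer', 'real'), overloads)
```

Because the first fitting signature wins, the order of the overloads matters: here two small integers pick the integer variant, and the real one is only chosen when an argument can't be squeezed into an integer.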
Dereference
We're done with the hard part! The rest should be a piece of cake:
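def _dereference(self, node: Dereference, expected: types.DataType, lvalue: bool):
    # only constrain the pointer when the result is constrained: nothing casts to Pointer(None)
    pointer = None if expected is None else types.Pointer(expected)
    target = self.visit(node.target, pointer, lvalue)
    if not isinstance(target, types.Pointer):
        raise WrongType(target)
    return target.type
Visit the target while expecting a pointer (to anything, if we don't care about the result's type), then return the type we point to.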
GetField
More or less the same here:
def _get_field(self, node: GetField, expected: types.DataType, lvalue: bool):
    target = self.visit(node.target, expected=None, lvalue=False)
    if isinstance(target, types.Reference):
        target = target.type
    if not isinstance(target, types.Record):
        raise WrongType(target)
    for field in target.fields:
        if field.name == node.name:
            return field.type
    raise WrongType(target, node.name)
Visit the target, make sure we've got a record, find the right field by name and return its type.
GetItem
And here as well:
def _get_item(self, node: GetItem, expected: types.DataType, lvalue: bool):
    target = self.visit(node.target, expected=None, lvalue=True)
    if isinstance(target, types.Reference):
        target = target.type
    if not isinstance(target, (types.StaticArray, types.DynamicArray)):
        raise WrongType(target)
    ndims = len(target.dims) if isinstance(target, types.StaticArray) else 1
    if len(node.args) != ndims:
        raise WrongType(target, node.args)
    args = self.visit_sequence(node.args, expected=types.Integer, lvalue=False)
    args = [x.type if isinstance(x, types.Reference) else x for x in args]
    if not all(isinstance(x, types.SignedInt) for x in args):
        raise WrongType(node)
    return target.type
The only difference is that arrays can have multiple indices and we must check that each index is an integer.
Unary
Pascal doesn't have many unary operators:
def _unary(self, node: Unary, expected: types.DataType, lvalue: bool):
    if node.op == '@':
        if not isinstance(expected, types.Pointer) or lvalue:
            raise WrongType(node)
        return types.Pointer(self.visit(node.value, expected=expected.type, lvalue=lvalue))
    return self.visit(node.value, expected, lvalue)
In case of taking an address (@) we check that it's not an lvalue and that we're expected to return a Pointer.
The rest are just +, - and not, which all return the same type as their argument, so we just visit the value with the same arguments.
Binary
Binary operators, as always, are a bit more interesting. There's a lot of type casting going on with them, e.g. we want to easily add a Real to an Integer, which makes perfect sense in most situations.
For me the simplest solution is to treat binary operators as simple functions with 2 arguments. We'll create a collection of such functions and use our _dispatch method to do all the work:
_numeric = [*types.Ints, *types.Floats]
_homogeneous = {
    '+': _numeric,
    '*': _numeric,
    '-': _numeric,
    '/': _numeric,
    'and': [types.Boolean],
    'or': [types.Boolean],
}
_boolean = {
    '=': _numeric,
    '<': _numeric,
    '<=': _numeric,
    '>': _numeric,
    '>=': _numeric,
    '<>': _numeric,
}
BINARY_SIGNATURES = {
    k: [types.Signature((v, v), v) for v in vs]
    for k, vs in _homogeneous.items()
}
BINARY_SIGNATURES.update({
    k: [types.Signature((v, v), types.Boolean) for v in vs]
    for k, vs in _boolean.items()
})
I'm writing from memory here, so I might be wrong, but I'm pretty sure almost all the operators either return the same type they received, or a Boolean in case of comparisons. One caveat: in standard Pascal / always yields a Real, even for two integers (integer division is div), so the '/' entry above is a simplification. That's what the code from above does: it synthetically generates a number of valid signatures for binary operators. So the _binary method itself becomes as easy as:
def _binary(self, node: Binary, expected: types.DataType, lvalue: bool):
    return self._dispatch([node.left, node.right], BINARY_SIGNATURES[node.op], expected).return_type
Not bad at all!
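In case the nested comprehensions are hard to read, here's a cut-down version you can run and inspect. It's a sketch with assumptions: plain strings stand in for the real types, the Signature class is a local copy, and I only kept three operators.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Signature:
    args: tuple
    return_type: str


# smallest type first, just like [*types.Ints, *types.Floats]
numeric = ['byte', 'integer', 'real']
homogeneous = {'+': numeric, 'and': ['boolean']}
comparisons = {'<': numeric}

signatures = {
    op: [Signature((kind, kind), kind) for kind in kinds]
    for op, kinds in homogeneous.items()
}
signatures.update({
    op: [Signature((kind, kind), 'boolean') for kind in kinds]
    for op, kinds in comparisons.items()
})
```

Each operator gets one signature per allowed type, listed from the smallest type to the largest. Since _dispatch returns the first signature that fits, this ordering is what makes 1 + 1 come out as a Byte rather than a Real.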
Expression statement
We're done with expressions! Now to statements.
def _expression_statement(self, node: ExpressionStatement):
    self.visit(node.value, expected=None, lvalue=False)
Super simple, just visit the expression, don't even care what's the return type.
If
def _if(self, node: If):
    self.visit(node.condition, expected=types.Boolean, lvalue=False)
    self.visit_sequence(node.then_)
    self.visit_sequence(node.else_)
We visit the condition making sure it's Boolean.
Then we unconditionally visit both branches. This contrasts with how If is evaluated at runtime. For now we're only interested in variable resolution and expression types, so we must visit both branches.
While
def _while(self, node: While):
    self.visit(node.condition, expected=types.Boolean, lvalue=False)
    self.visit_sequence(node.body)
Almost the same thing here.
For
And the final node:
def _for(self, node: For):
    counter = self.visit(node.name, expected=None, lvalue=True)
    if not isinstance(counter, types.SignedInt):
        raise WrongType(counter)
    self.visit(node.start, expected=counter, lvalue=False)
    self.visit(node.stop, expected=counter, lvalue=False)
    self.visit_sequence(node.body)
For has a counter variable which we assign values to, and it must be an integer.
For the rest, this is just a combination of If and While, nothing new.
Casting rules
As I promised, here's the can_cast method, which handles all type casting:
def can_cast(self, kind: types.DataType, to: types.DataType) -> bool:
    # we either don't care (to is None) or they're both the same type
    if to is None or kind == to:
        return True
    match kind, to:
        # references are just a wrapper, so we'll ignore them
        case types.Reference(src), _:
            return self.can_cast(src, to)
        case _, types.Reference(dst):
            return self.can_cast(kind, dst)
        # static arrays can be viewed as dynamic in some cases, if they're 1-dimensional
        case types.StaticArray(dims, src), types.DynamicArray(dst):
            return len(dims) == 1 and src == dst
        # ints can be cast to floats
        case types.SignedInt(_), types.Floating(_):
            return True
    # basic upcasting, e.g. Byte -> Integer
    for family in types.SignedInt, types.Floating:
        if isinstance(kind, family) and isinstance(to, family):
            return kind.bits <= to.bits
    # no luck
    return False
I added comments to the relevant parts, so this should be pretty straightforward.
A bit of magic
Finally, there's one more important bit that we need to talk about: writeln.
It turns out that Pascal does a bit of cheating, and exposes several magic functions that can't be implemented in the language itself. The best known of them is writeln: it can accept any number of arguments and each of them can have any type from a long list of allowed types. So, once again, this is 100% legal:
writeln;
writeln(1);
writeln(1, 2.5);
writeln(1, 2.5, 'my string');
There is simply no way for us to squeeze this behaviour into our _dispatch method: in a way, writeln is an infinite number of overloaded functions.
To handle this we'll introduce a new concept:
from abc import ABC, abstractmethod
class MagicFunction(ABC):
    @classmethod
    @abstractmethod
    def validate(cls, args, visit) -> DataType:
        pass
Yes, at this point we could replace this class with a function, but we'll extend it later, so introducing a new type is super legit, I promise!
validate will have to, wait for it, validate the incoming arguments and decide the return type of the function.
In case of writeln this should look like so:
class WriteLn(MagicFunction):
    @classmethod
    def validate(cls, args, visit) -> DataType:
        for arg in args:
            # we can write ~almost~ anything, so we don't care about the type
            visit(arg, None, False)
        return types.Void
There are a few places where we'll need to add support for it, but first, let's create a registry of magic functions:
MAGIC_FUNCTIONS = {
'writeln': WriteLn,
}
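Before wiring this into the visitor, we can poke at it in isolation. In this sketch fake_visit is a made-up stand-in for TypeSystem.visit that only records its calls, the string 'void' stands in for types.Void, and plain Python values play the role of argument nodes:

```python
from abc import ABC, abstractmethod


class MagicFunction(ABC):
    @classmethod
    @abstractmethod
    def validate(cls, args, visit):
        pass


class WriteLn(MagicFunction):
    @classmethod
    def validate(cls, args, visit):
        for arg in args:
            visit(arg, None, False)
        return 'void'  # stands in for types.Void


MAGIC_FUNCTIONS = {'writeln': WriteLn}

# a fake `visit` that only records how it was called
calls = []


def fake_visit(node, expected, lvalue):
    calls.append((node, expected, lvalue))


magic = MAGIC_FUNCTIONS['writeln']()
result = magic.validate([1, 2.5, 'my string'], fake_visit)
```

Every argument gets visited with no expected type and as an rvalue, no matter how many of them there are, which is the kind of flexibility _dispatch can't give us.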
Program
As soon as we enter the global scope, we must define all the magic functions:
def _program(self, node: Program):
    with self._enter():
        # magic
        for name, magic in MAGIC_FUNCTIONS.items():
            self._store(name, magic(), None)
        # ... the rest
Call
Next, calling magic functions needs special treatment:
def _call(self, node: Call, expected: types.DataType, lvalue: bool):
    if not isinstance(node.target, Name):
        raise WrongType(node)
    # get all the functions with this name
    kind, targets = self._resolve(node.target.normalized)
    if isinstance(kind, MagicFunction):
        return kind.validate(node.args, self.visit)
    # ... the rest
Name
And finally, we should desugar 0-arg calls:
# replace
if isinstance(kind, types.Function):
    ...
# by
if isinstance(kind, (types.Function, MagicFunction)):
    ...
That's it! Now we have a real-life type system. This took a while, but I hope you found something useful.
You probably noticed that this is already the third post in this series, and still there's no LLVM in sight. In the next and final post we'll fix that. Next time we'll use all the concepts we built so far to compile everything to LLVM's IR!