Rea: a theorem prover written in C++ (Part 3: The type system)
Recap & Introduction
Rea Code: https://github.com/lackhoa/rea-engine
This is a series about a theorem prover using C++. With the parser out of the way, in this part we are gonna attack the type system. Things are gonna get serious from here because this is the core part of the language, so be prepared because you’re gonna learn some weird stuff you don’t need anywhere else in the world. In fact, it gets so serious that I get mildly depressed trying to write this part. That said I’ll still try to add in as much useful practical information as I can.
Please refresh your memory from part 1, because I described in that part a bit about what the type system should do.
What are we doing today?
Last episode, we built up the AST. But ASTs are made of tokens. This time, we convert the AST to terms, which is the final transformation that our source code will go through (phew!).
To motivate the concept of “term”, let us go through an example. Since I haven’t introduced rea language syntax yet, let’s demonstrate with C for now.
bool negate(bool x)
{
bool out = !x;
return out;
}
static int out = 5;
int return5()
{
return out;
}
So in both functions “negate” and “return5” above, there is a statement “return out”. Let’s forget the “return” and focus on the exact symbol “out”. They both look exactly the same at the AST level, but they represent different things:
- one is a boolean, the other is an int
- one is a variable whose value we don’t know at compile time, while the other is a constant
The type system is in charge of figuring all of this out. As you can already see above, in order to figure out what the type of “out” is, we need to figure out exactly what the name “out” refers to. So the job of the type system is mostly to figure out the semantics of the program, i.e. the part of the program that doesn’t depend on variable names. Some people call what the type system does “typing” or “typechecking”. But I prefer to call this process “term building” or just “building” for short, since the core part of the work isn’t about types.
In a sense, our job is to figure out what’s common in these two functions
bool negate(bool x)
{
bool out = !x;
return out;
}
bool negate2(bool x)
{
bool neg = !x;
return neg;
}
not that we need to compare functions for some reason.
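To make "what's common" a bit more concrete, here is a minimal sketch (not Rea code; "eraseNames" and "demoSameShape" are names I made up for illustration). It replaces every name with the order in which it first showed up, so the only thing left is which occurrence refers to which.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Replace each name by the order in which it first appeared (0, 1, 2, ...).
std::vector<int> eraseNames(const std::vector<std::string> &names) {
    std::unordered_map<std::string, int> slot_of;
    std::vector<int> slots;
    for (const std::string &name : names) {
        auto it = slot_of.find(name);
        if (it == slot_of.end()) {
            int fresh = (int)slot_of.size();  // next unused slot number
            it = slot_of.emplace(name, fresh).first;
        }
        slots.push_back(it->second);
    }
    return slots;
}

// Usage: the names that occur in "negate" and "negate2", in source order.
void demoSameShape() {
    std::vector<int> negate  = eraseNames({"x", "out", "x", "out"});
    std::vector<int> negate2 = eraseNames({"x", "neg", "x", "neg"});
    assert(negate == negate2);  // the names differ, the shape does not
}
```

Once the names are gone, the two functions are indistinguishable, which is roughly the point of view a term takes.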
The other way around
I’ve shown you that you can’t know what type something has without knowing what it is. You’d think that typechecking would be a clean two-step process:
- figure out what it is, and
- figure out what type it has.
But it’s not so straightforward, since the other way around is also true: you won’t know what something is without knowing its type!
In math you have all these objects that have the same name but with different types. Well, forget math, you can already see that from programming, like in C
auto x = 5; // int
auto y = 5.0f; // float
auto z = 5.0; // double
Having to write 5.0f all the time is a pain. Luckily, when we write auto x = 5 + 10.0f, the compiler can automatically figure out that 5 should be a float. This is done via what’s called “type promotion” (the C standard files this particular case under the “usual arithmetic conversions”).
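You can ask a C++ compiler to confirm this. The sketch below uses only standard facilities (static_assert, decltype, std::is_same); the name "promoted_sum" is made up for the example.

```cpp
#include <type_traits>

// On its own, the literal 5 is an int.
static_assert(std::is_same<decltype(5), int>::value, "5 is an int");
// Next to a float operand, the int is converted and the sum is a float.
static_assert(std::is_same<decltype(5 + 10.0f), float>::value, "int + float is float");
// Next to a double operand, the sum is a double instead.
static_assert(std::is_same<decltype(5 + 10.0), double>::value, "int + double is double");

// The same literal 5 ends up meaning 5.0f here.
constexpr float promoted_sum = 5 + 10.0f;
```

So the meaning of the token 5 depends on what sits next to it, which is exactly the "other way around" problem.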
C has it easy, because they only have to convert between a few types. C++ has it harder because they have user-defined type conversion. And of course they manage to mess it up somehow so I ended up not being able to rely on it completely (these guys should receive an award for accomplishing the largest number of half-baked features ever).
Speaking from experience, over-reliance on type conversion could lead to disasters. For example in C++ we can have this kind of scenario
struct Bar
{
int value;
operator bool() { return value; }
};
int foo(bool x) {return x;}
int foo(Bar x) {return x.value + 1;} // <- imagine commenting this line out
int main()
{
Bar x = Bar{5};
return foo(x);
}
The code above will return either 6 or 1, depending on whether the int foo(Bar x) overload exists or not.
Einstein used to call this “spooky action at a distance” back in the days.
Unfortunately, just like in theoretical physics, sometimes we don’t have a choice whether some things exist in our programming language or not. This is sadly one of those times.
Well perhaps not so sadly, because it is the name of the game. As the great mathematician Henri Poincare put it, “mathematics is the art of giving the same names to different things”. Think of the symbol “1” for example, it might stand for
- A natural number
- A fraction
- A real number
- An identity matrix
- A complex number
- The identity function
- … the list goes on and on forever
Even if we were to try to give separate names to all those things, we would have a hard time coming up with them. Even in programming, we can’t be God.
So now we have a seemingly impossible problem: we can’t typecheck without semantics information. And we can’t obtain semantics without the type. There is actually a way to solve this problem satisfactorily, and I will explain it eventually once we’ve gotten far enough.
So what are we gonna do?
This is the point where it’s hard to even think about what to do. Let’s roll with our usual approach and set up an example to see what our requirements would be. This is probably the first line of code I wrote using rea
Bool :: union {false, true}
& :: fn (a,b: Bool) -> Bool
{fork a
{false: return false,
true: return b}}
Let me describe in English what the code above would accomplish if the system is working correctly. I will trust you to infer the syntax :>
- We define a new type called “Bool”. This is kind of like an enum in C, with two members: “false” and “true”.
- We define “&” to be a function with type “(a,b: Bool) -> Bool”, meaning that it takes two Booleans “a” and “b”, and returns another Boolean. The “fork a” is like a C switch statement on “a”: so if “a” is false, we return “false”; and if “a” is true, we return the value of “b”.
What would it take to understand the semantics of “&”? If we do it by hand, it would go straightforwardly like this:
- First look at the signature “(a,b: Bool) -> Bool”; we have to figure out its semantics.
  - The first thing to notice is that it’s an “arrow” type (see the arrow?). An arrow type is a kind of type that describes functions.
  - The “(a,b: Bool)” part says that it takes two arguments of type “Bool”.
    - What is Bool? Oh, earlier we saw that “Bool” was defined to be a type with two values, “false” and “true”. So good, at least this part makes sense.
  - The “-> Bool” part means that it returns a result of type “Bool”. Similarly, we look up the meaning of Bool and find that it was defined.
- Next we check the body of the function, “{fork a {false: return false, true: return b}}”; let’s obtain its semantics.
  - The form of this body is “fork”.
  - What is it forking? “a”.
  - What is “a”? We see that there’s one parameter called “a”, so it must be that.
  - Since “a” is a “Bool” (as we have figured out from the signature), the fork needs two branches for “false” and “true”. And it does!
  - We check the branch “false”. It says that we return “false”.
  - What is “false”? It is a member of type “Bool”. Does that fit the bill? Yes, we are trying to return Bool.
  - (I’ll leave you to typecheck the other branch)
- Conclusion: this function is well-typed!
Meaning that we understand what it claims to do (“(a,b: Bool) -> Bool”), and that it does exactly that.
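The manual walk over the fork can be sketched in a few lines of C++. To be clear, this is not how Rea does it: "UnionType", "Expr", "Branch", "Fork" and "checkFork" are hypothetical stand-ins, and I assume every sub-expression already knows its type, which skips the hard part.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in: a union type is a name plus its members.
struct UnionType {
    std::string name;
    std::vector<std::string> members;
};

// An already-understood expression; we only track its type here.
struct Expr {
    std::string text;
    const UnionType *type;
};

struct Branch {
    std::string label;  // which member of the subject's type this branch handles
    Expr result;        // what the branch returns
};

struct Fork {
    Expr subject;
    std::vector<Branch> branches;
};

// Mirrors the manual steps: one branch per member of the subject's type
// (in declaration order), and every branch must return the goal type.
bool checkFork(const Fork &fork, const UnionType *goal) {
    assert(fork.subject.type && goal);
    const std::vector<std::string> &members = fork.subject.type->members;
    if (fork.branches.size() != members.size()) return false;
    for (std::size_t i = 0; i < members.size(); i++) {
        if (fork.branches[i].label != members[i]) return false;
        if (fork.branches[i].result.type != goal) return false;
    }
    return true;
}
```

Feeding it the body of “&” (subject “a”, branches “false” and “true” returning “false” and “b”) with the goal Bool gives true; dropping a branch or returning something of another type gives false.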
Understanding is good and all, but we need to store this understanding as a data structure. That data structure is called a “term”. Let me show you what the term for “&” looks like, in a cursory manner
Bool :: union {false, true}
& :: fn (a,b: Bool) -> Bool
{fork a
{false: return false,
true: return b}}
(Suppose 0xB001 addresses "Bool", 0x0000 addresses "false")
Here is the term representing "&"
- Name: "&"
- Type: an arrow type
+ Parameters: 2
* Name: "a"; Type: 0xB001
* Name: "b"; Type: 0xB001
+ Return type: 0xB001
- Body: a fork
+ Fork subject: variable [0,0] (meaning "a", explanation below)
+ Branches: 2
* Expression: 0x0000
* Expression: variable [0,1] (meaning "b", explanation below)
The above is straightforward, except for the mysterious part: how to represent variables.
You can already kinda guess how we do it from above: the function has two parameters “a” and “b”, so [0,0] represents the first parameter, and [0,1] represents the second parameter.
The first 0 is used to distinguish variables of different functions, because you know: sometimes a function needs to reference parameters of another function. There is only one scenario where that can happen. We’ll demonstrate it with the function “&’” below
&' :: fn (a: Bool) -> Bool -> Bool
{return fn (b : Bool) -> Bool
{fork a
{false: return false,
true: return b}}}
“&’” is functionally the same as “&”, except that it only accepts one Boolean as argument. But instead of returning a Boolean result right away, it returns another function. That inner function accepts one Boolean and returns the final Boolean result. This is the term breakdown of “&’”
(Suppose 0xB001 addresses "Bool", 0x0000 addresses "false")
- Name: "&'"
- Type: an arrow type
+ Parameters: 1
* Name: "a"; Type: 0xB001
+ Return type: an arrow type
* Parameters: 1
- Name: "<none>"; Type: 0xB001
- Body: a function
- Type: an arrow type
+ Parameters: 1
* Name: "b"; Type: 0xB001
+ Return type: 0xB001
- Body: a fork
+ Fork subject: variable [1,0] (meaning "a")
+ Branches: 2
* Expression: 0x0000
* Expression: variable [0,0] (meaning "b")
Notice the fork subject is the variable “[1,0]”, meaning that instead of referring to the parameter “b” (of the inner function), it refers to the parameter “a” (of the outer function “&’”). I say that this variable has “delta=1” (and “index=0”). I hope you can see how a delta of 2 allows an inner function of an inner function to refer to the parameters of the outermost function, and so on. The nesting of functions is the only place where a function can refer to something that is not one of its parameters.
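If the [delta,index] pair still feels abstract, here is a small sketch of how such a variable could be resolved when functions are actually called. It is an illustration only: "Frame", "lookupVariable", "evalInnerBody" and "andPrime" are made-up names, and Rea's real evaluator is not shown in this article.

```cpp
#include <cassert>
#include <vector>

// One frame holds the argument values of one function call.
using Frame = std::vector<bool>;

struct Variable { int delta; int index; };

// frames.back() belongs to the innermost function; delta counts outward from it.
bool lookupVariable(const std::vector<Frame> &frames, Variable var) {
    assert(var.delta >= 0 && var.delta < (int)frames.size());
    const Frame &frame = frames[frames.size() - 1 - var.delta];
    assert(var.index >= 0 && var.index < (int)frame.size());
    return frame[var.index];
}

// Body of the inner function of "&'": fork on [1,0] ("a"),
// return false in the "false" branch and [0,0] ("b") in the "true" branch.
bool evalInnerBody(const std::vector<Frame> &frames) {
    bool a = lookupVariable(frames, Variable{1, 0});
    if (!a) return false;
    return lookupVariable(frames, Variable{0, 0});
}

bool andPrime(bool a, bool b) {
    std::vector<Frame> frames;
    frames.push_back(Frame{a});  // calling "&'" binds "a"
    frames.push_back(Frame{b});  // calling the returned function binds "b"
    return evalInnerBody(frames);
}
```

With a third nested function you would push a third frame, and a variable with delta=2 would reach all the way back to “a”.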
Detailed implementation
This is the point where it gets hard to show all the details, because the actual implementation handles many more complications than what was shown above. But I’ll try my best to remove the parts of the code you shouldn’t see yet. So here’s the code that would handle term building
enum TermKind {
Term_Union,
Term_Constructor,
Term_Function,
Term_Fork,
Term_Variable,
Term_Composite,
Term_Arrow,
Term_Pointer,
};
struct Term {
TermKind kind;
i32 serial;
Term *type;
Token *global_name;
};
enum PointerKind
{
Pointer_Stack,
Pointer_Heap // we don't need to care about this yet!
};
struct Variable : Term {
String name;
i32 delta;
i32 index;
};
struct StackPointer {
String name;
i32 depth;
i32 index;
Term *value;
};
struct Pointer : Term {
Record *ref;
PointerKind pointer_kind;
StackPointer stack;
};
struct LocalBinding
{
i32 hash;
String key;
i32 var_index;
LocalBinding *tail;
};
struct LocalBindings {
Arena *arena;
LocalBinding table[128];
LocalBindings *tail;
};
struct Scope {
Scope *outer;
i32 depth;
i32 param_count;
Pointer **pointers;
};
struct Typer {
LocalBindings *bindings;
Scope *scope;
};
internal Term *
buildComposite(Typer *typer, CompositeAst *in, Term *goal)
{
// hidden code: we don't care about this yet
}
struct LookupCurrentFrame { LocalBinding* slot; b32 found; };
internal LookupCurrentFrame
lookupCurrentFrame(LocalBindings *bindings, String key, b32 add_if_missing) {
// hidden code, basically lookup the key in bindings->table
// This part is omitted partly because it does a hash table lookup,
// and I have no idea why I did that instead of just a linear search...
}
inline Term *
lookupLocalName(Typer *typer, Token *token)
{
String name = token->string;
LocalBindings *bindings = typer->bindings;
Term *out = {};
Scope *scope = typer->scope;
for (i32 stack_delta = 0; bindings; stack_delta++)
{
LookupCurrentFrame lookup = lookupCurrentFrame(bindings, name, false);
if (lookup.found)
{
assert(lookup.slot->var_index < scope->param_count);
out = scope->pointers[lookup.slot->var_index];
break;
}
else
{
bindings = bindings->tail;
scope = scope->outer;
}
}
return out;
}
inline Pointer *
newStackPointer(String name, i32 depth, i32 index, Term *type)
{
assert(depth);
Pointer *out = newTerm(temp_arena, Pointer, type);
out->pointer_kind = Pointer_Stack;
out->stack.name = name;
out->stack.depth = depth;
out->stack.index = index;
return out;
}
// This function recursively turns stack pointers back into variables
inline Term *
toAbstractTerm_(AbstractContext *ctx, Term *in0)
{
assert(ctx->zero_depth >= ctx->env_depth);
Term *out0 = 0;
Arena *arena = ctx->arena;
if (isGlobalValue(in0))
{
out0 = in0;
}
else
{
switch (in0->kind)
{
case Term_Pointer:
{
Pointer *in = castTerm(in0, Pointer);
if (in->pointer_kind == Pointer_Stack)
{
if (in->stack.depth > ctx->env_depth)
{
Variable *out = newTerm(arena, Variable, 0);
out->name = in->stack.name;
out->delta = ctx->zero_depth - in->stack.depth;
out->index = in->stack.index;
assert(out->delta >= 0);
out0 = out;
}
else
{
out0 = in0;
}
}
} break;
case Term_Arrow:
{
Arrow *in = castTerm(in0, Arrow);
Arrow *out = copyTerm(arena, in);
ctx->zero_depth++;
allocateArray(arena, in->param_count, out->param_types);
for (i32 i=0; i < in->param_count; i++)
{
out->param_types[i] = toAbstractTerm_(ctx, in->param_types[i]);
}
out->output_type = toAbstractTerm_(ctx, in->output_type);
ctx->zero_depth--;
out0 = out;
} break;
case Term_Fork:
{
Fork *in = castTerm(in0, Fork);
Fork *out = copyTerm(arena, in);
out->subject = toAbstractTerm_(ctx, in->subject);
out->cases = pushArray(arena, in->case_count, Term*);
for (i32 i=0; i < in->case_count; i++)
{
out->cases[i] = toAbstractTerm_(ctx, in->cases[i]);
}
out0 = out;
} break;
case Term_Function:
{
Function *in = castTerm(in0, Function);
Function *out = copyTerm(arena, in);
ctx->zero_depth++;
out->body = toAbstractTerm_(ctx, in->body);
ctx->zero_depth--;
out0 = out;
} break;
case Term_Variable:
{
out0 = copyTerm(arena, (Variable *)in0);
} break;
}
if (out0 != in0)
{
out0->type = toAbstractTerm_(ctx, in0->type);
}
}
assert(out0);
return out0;
}
inline void
addLocalBinding(Typer *typer, String key, i32 var_index)
{
auto lookup = lookupCurrentFrame(typer->bindings, key, true);
lookup.slot->var_index = var_index;
}
inline Scope *
introduceSignature(Scope *outer_scope, Arrow *signature)
{
i32 param_count = signature->param_count;
Scope *scope = newScope(outer_scope, param_count);
for (i32 param_i=0; param_i < param_count; param_i++)
{
String name = signature->param_names[param_i];
Term *type = toValue(scope, signature->param_types[param_i]);
// :fixpoint-pointers
scope->pointers[param_i] = newStackPointer(name, scope->depth, param_i, type);
}
return scope;
}
inline Term *
toValue(Scope *scope, Term *in0)
{
// hidden code: let's just say this doesn't do anything for now.
}
inline LocalBindings *
extendBindings(Typer *typer)
{
LocalBindings *out = pushStruct(temp_arena, LocalBindings, true);
out->tail = typer->bindings;
out->arena = temp_arena;
typer->bindings = out;
return out;
}
inline Function *
buildFunctionGivenSignature(Typer *typer0, Arrow *signature, Ast *in_body,
Token *global_name=0, FunctionAst *in_fun=0)
{
i32 unused_var serial = DEBUG_SERIAL;
// :persistent-global-function.
// NOTE: We MUST control the arena, otherwise the self-reference will be
// invalidated (TODO: maybe the copier should know about self-reference).
Arena *arena = global_name ? global_state.top_level_arena : temp_arena;
Function *out = newTerm(arena, Function, signature);
if (in_fun) out->function_flags = in_fun->function_flags;
Term *body = 0;
Typer typer_ = *typer0;
{
Typer *typer = &typer_;
typer->scope = introduceSignature(typer->scope, signature);
extendBindings(typer);
if (noError())
{
for (i32 index=0; index < signature->param_count; index++)
{
String name = signature->param_names[index];
addLocalBinding(typer, name, index);
}
Term *body_type = toValue(typer->scope, signature->output_type);
if ((body = buildTerm(typer, in_body, body_type)))
{
if (signature->output_type == hole)
{
// write back the output type
signature->output_type = toAbstractTerm(temp_arena, body->type, getScopeDepth(typer0));
Term *check = toValue(typer->scope, signature->output_type);
assertEqualNorm(body->type, check);
}
}
}
}
if (body)
{
out->body = toAbstractTerm(temp_arena, body, getScopeDepth(typer0));
if (global_name)
{
out->body = copyToGlobalArena(out->body);
}
}
NULL_WHEN_ERROR(out);
return out;
}
inline b32
matchType(Term *actual, Term *goal)
{
// hidden code: just pretend that it's deep-comparing the terms
// but in reality we have to normalize the terms first and compare them
}
forward_declare internal BuildTerm
buildTerm(Typer *typer, Ast *in0, Term *goal0)
{
// beware: Usually we mutate in-place, but we may also allocate anew.
i32 UNUSED_VAR serial = DEBUG_SERIAL++;
assert(goal0);
Arena *arena = temp_arena;
Term *value = 0;
b32 should_check_type = true;
b32 recursed = false;
switch (in0->kind)
{
case Ast_CompositeAst:
{
value = buildComposite(typer, (CompositeAst*)in0, goal0);
} break;
case Ast_Identifier:
{
Identifier *in = castAst(in0, Identifier);
Token *name = &in->token;
if (Term *v = lookupLocalName(typer, name))
{
value = v;
}
else
{
if (GlobalSlot *slot = lookupGlobalName(name))
{
for (i32 value_id = 0; value_id < slotCount(slot); value_id++)
{
Term *slot_value = slot->terms[value_id];
if (matchType((slot_value)->type, goal0))
{
value = slot_value;
}
}
}
if (value)
{
should_check_type = false;
}
}
} break;
case Ast_ArrowAst:
{
ArrowAst *in = (ArrowAst *)(in0);
Arrow *out = newTerm(arena, Arrow, rea.Type);
i32 param_count = in->param_count;
// :arrow-copy-done-later
out->param_count = param_count;
out->param_names = in->param_names;
out->param_flags = in->param_flags;
if (noError())
{
Scope *scope = typer->scope = newScope(typer->scope, param_count);
extendBindings(typer);
out->param_types = pushArray(arena, param_count, Term*);
for (i32 i=0; i < param_count && noError(); i++)
{
if (Term *param_type = buildTerm(typer, in->param_types[i], hole))
{
out->param_types[i] = toAbstractTerm(arena, param_type, getScopeDepth(scope->outer));
String name = in->param_names[i];
scope->pointers[i] = newStackPointer(name, scope->depth, i, param_type);
addLocalBinding(typer, name, i);
}
}
if (noError())
{
value = out;
// :no-output-type
if (in->output_type)
{
if (Term *output_type = buildTerm(typer, in->output_type, hole))
{
out->output_type = toAbstractTerm(arena, output_type, getScopeDepth(scope->outer));
}
}
else
{
out->output_type = hole;
}
}
unwindBindingsAndScope(typer);
}
} break;
case Ast_FunctionAst:
{
FunctionAst *in = castAst(in0, FunctionAst);
Term *type = goal0;
if (in->signature)
{
if (isNameOnlyArrowType(in->signature))
{
type = buildNameOnlyArrowType(typer, in->signature, goal0);
}
else
{
type = buildTerm(typer, in->signature, hole);
}
}
if (noError())
{
if (Arrow *signature = castTerm(type, Arrow))
{
value = buildFunctionGivenSignature(typer, signature, in->body, 0, in);
}
else
{
reportError(in0, "function signature is required to be an arrow type");
attach("type", type);
}
}
} break;
case Ast_ForkAst:
{
ForkAst *fork = castAst(in0, ForkAst);
value = buildFork(typer, fork, goal0);
recursed = true;
} break;
}
if (noError())
{// typecheck if needed
Term *actual = value->type;
if (should_check_type && !recursed)
{
if (!matchType(actual, goal0))
{
if (expectedWrongType(typer)) silentError();
else
{
reportError(in0, "actual type differs from expected type");
attach("got", actual);
}
}
}
}
return BuildTerm{.value=value};
}
The terminology “StackPointer” is used because it has a dual notion, “HeapPointer”, which we haven’t gotten into yet. Just think of pointers as a representation for variables, as I explained above. You might notice that the code for building identifiers is a bit more complicated than necessary. That’s because we need to do something more complicated later on.
These are the places where local variables are concerned
// "buildTerm" -> arrow type: build and record the parameter types
// so that later on the function building code can reuse them (kinda)
case Ast_ArrowAst:
{
for (i32 i=0; i < param_count && noError(); i++)
{
if (Term *param_type = buildTerm(typer, in->param_types[i], hole))
{
out->param_types[i] = toAbstractTerm(arena, param_type, getScopeDepth(scope->outer));
...
}
}
} break;
// "buildTerm" -> identifier
case Ast_Identifier:
{
Identifier *in = castAst(in0, Identifier);
Token *name = &in->token;
if (Term *v = lookupLocalName(typer, name))
{
value = v;
}
else ...
} break;
// buildFunctionGivenSignature
typer->scope = introduceSignature(typer->scope, signature);
extendBindings(typer);
if (noError())
{
for (i32 index=0; index < signature->param_count; index++)
{
String name = signature->param_names[index];
addLocalBinding(typer, name, index); // <- bindings are added here
}
if ((body = buildTerm(typer, in_body, body_type)))
}
In short, we build the term with a typer (I should have called this “builder”, oh well too late). The typer holds the current bindings and the scope. The bindings and the scope go hand-in-hand. The typer builds identifiers as follows:
- It looks up variable names in the local bindings
- If the current bindings contain the name, then we return the corresponding pointer that is held in the scope.
- If the current bindings don’t contain the name, we consult the outer bindings.
- If we’ve run out of local bindings, we consult the global bindings.
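The lookup order above can be sketched with standard containers. This is a simplification under my own assumptions: the real "lookupLocalName" returns the pointer stored in the scope, while this sketch just reports where the name was found. "Frame", "Resolved", "resolveName" and "demoResolve" are hypothetical names.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// name -> parameter index, for the parameters of one function
using Frame = std::unordered_map<std::string, int>;

struct Resolved {
    enum Kind { Local, Global, NotFound } kind;
    int frames_out;  // locals only: 0 = innermost function, 1 = the one around it, ...
    int var_index;   // locals only: position in that function's parameter list
};

// locals.back() is the innermost frame; walk outward, then fall back to globals.
Resolved resolveName(const std::vector<Frame> &locals,
                     const std::unordered_set<std::string> &globals,
                     const std::string &name) {
    for (int out = 0; out < (int)locals.size(); out++) {
        const Frame &frame = locals[locals.size() - 1 - out];
        auto it = frame.find(name);
        if (it != frame.end()) return Resolved{Resolved::Local, out, it->second};
    }
    if (globals.count(name)) return Resolved{Resolved::Global, 0, 0};
    return Resolved{Resolved::NotFound, 0, 0};
}

// Usage: names as seen from the body of the inner function of "&'".
void demoResolve() {
    std::vector<Frame> locals = {Frame{{"a", 0}}, Frame{{"b", 0}}};
    std::unordered_set<std::string> globals = {"Bool", "false", "true"};
    assert(resolveName(locals, globals, "b").frames_out == 0);
    assert(resolveName(locals, globals, "a").frames_out == 1);
    assert(resolveName(locals, globals, "false").kind == Resolved::Global);
    assert(resolveName(locals, globals, "oops").kind == Resolved::NotFound);
}
```

Note that a local name shadows a global one simply because the local frames are consulted first.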
In the end, pointers will just be turned to variables by the function “toAbstractTerm” (with the correct delta and index). You are right to ask: why don’t we just turn identifiers into variables, rather than pointers? The answer (as we’ll see later) is that we cannot afford to let variables freely flow through the system, since they don’t mean anything when they stand alone. You’ll see how variables flow through the type system later on when we get to using dependent types.
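The pointer-to-variable step boils down to one subtraction, the same one you can spot in “toAbstractTerm_” (delta = zero_depth - depth). Here is a stripped-down sketch; "StackSlot" and "toVariable" are made-up names, and I assume the outermost function's scope has depth 1.

```cpp
#include <cassert>

struct StackSlot { int depth; int index; };  // absolute: which scope, which parameter
struct Variable  { int delta; int index; };  // relative: how many functions outward

// zero_depth is the depth of the innermost function enclosing the pointer.
Variable toVariable(StackSlot slot, int zero_depth) {
    assert(slot.depth >= 1 && slot.depth <= zero_depth);
    return Variable{zero_depth - slot.depth, slot.index};
}
```

For “&’”, the parameter “a” lives at depth 1 and “b” at depth 2. Seen from the inner body (zero_depth = 2), “a” becomes [1,0] and “b” becomes [0,0], matching the term breakdown earlier.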
Conclusion
So I’ve told you about terms, and how to represent variables. For now you should just think in a cursory way about how that might work. I cannot explain the details to you yet, because the way we handle variables also has to work with dependent types. Maybe we’ll get to that next time.