Type Inference
Ori uses Hindley-Milner (HM) type inference, extended with features for patterns and capabilities.
How HM Inference Works
1. Fresh Type Variables
When a type is unknown, create a fresh type variable:
fn fresh_type_var(&mut self) -> Type {
    let id = self.next_var;
    self.next_var = TypeVarId(self.next_var.0 + 1);
    Type::TypeVar(id)
}
2. Constraint Generation
Walk the AST and generate equality constraints:
fn infer_expr(&mut self, expr: ExprId) -> Type {
    let expr_data = self.arena.get(expr);

    match &expr_data.kind {
        ExprKind::Literal(Literal::Int(_)) => Type::Int,

        ExprKind::Ident(name) => {
            self.env.lookup(*name).cloned()
                .unwrap_or_else(|| self.error_undefined(*name))
        }

        ExprKind::Binary { left, op, right } => {
            let left_ty = self.infer_expr(*left);
            let right_ty = self.infer_expr(*right);
            self.infer_binary_op(*op, left_ty, right_ty)
        }

        ExprKind::Let { name, value, body } => {
            let value_ty = self.infer_expr(*value);
            self.env.push_scope();
            self.env.bind(*name, value_ty);
            let body_ty = self.infer_expr(*body);
            self.env.pop_scope();
            body_ty
        }

        // ...
    }
}
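The excerpt above delegates the actual constraints to helpers such as infer_binary_op. As a rough illustration of what "generating equality constraints" can look like, here is a minimal sketch over a hypothetical mini-AST. The `Expr`, `Gen`, and `constraints` names are assumptions made for this example and are not Ori's real types, and treating `+` as Int-only is a deliberate simplification.

```rust
// Hypothetical, reduced types for illustration only.
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    Bool,
    TypeVar(u32),
}

enum Expr {
    Int(i64),
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
    If(Box<Expr>, Box<Expr>, Box<Expr>),
}

#[derive(Default)]
struct Gen {
    next_var: u32,
    env: Vec<(&'static str, Type)>,
    // Each pair (a, b) records the requirement "a must equal b".
    constraints: Vec<(Type, Type)>,
}

impl Gen {
    fn fresh(&mut self) -> Type {
        let ty = Type::TypeVar(self.next_var);
        self.next_var += 1;
        ty
    }

    fn infer(&mut self, expr: &Expr) -> Type {
        match expr {
            Expr::Int(_) => Type::Int,
            Expr::Var(name) => {
                // Innermost binding wins. A real checker would report an
                // undefined name; this sketch falls back to a fresh variable.
                let found = self
                    .env
                    .iter()
                    .rev()
                    .find(|(n, _)| n == name)
                    .map(|(_, t)| t.clone());
                match found {
                    Some(ty) => ty,
                    None => self.fresh(),
                }
            }
            Expr::Add(left, right) => {
                let left_ty = self.infer(left);
                let right_ty = self.infer(right);
                // Assumption: `+` is Int-only in this sketch.
                self.constraints.push((left_ty, Type::Int));
                self.constraints.push((right_ty, Type::Int));
                Type::Int
            }
            Expr::If(cond, then_branch, else_branch) => {
                let cond_ty = self.infer(cond);
                let then_ty = self.infer(then_branch);
                let else_ty = self.infer(else_branch);
                self.constraints.push((cond_ty, Type::Bool));
                // Both branches must agree on one type.
                self.constraints.push((then_ty.clone(), else_ty));
                then_ty
            }
        }
    }
}
```

Inferring `x + 1` with `x : T0` in scope yields the constraints (T0, Int) and (Int, Int); a solver then feeds each pair to unification.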
3. Unification
Solve constraints by unifying types:
fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
    // Resolve both sides first so solved variables are compared by their bindings
    let t1 = self.apply_subst(t1);
    let t2 = self.apply_subst(t2);

    match (&t1, &t2) {
        // Same type - ok
        (Type::Int, Type::Int) => Ok(()),

        // Same type variable on both sides - nothing to bind
        // (without this arm the occurs check below would reject T = T)
        (Type::TypeVar(a), Type::TypeVar(b)) if a == b => Ok(()),

        // Type variable - bind it
        (Type::TypeVar(id), ty) | (ty, Type::TypeVar(id)) => {
            if self.occurs_in(*id, ty) {
                Err(TypeError::InfiniteType)
            } else {
                self.substitution.insert(*id, ty.clone());
                Ok(())
            }
        }

        // Compound types - recurse
        (Type::List(a), Type::List(b)) => self.unify(a, b),

        (Type::Function { params: p1, ret: r1 },
         Type::Function { params: p2, ret: r2 }) => {
            if p1.len() != p2.len() {
                return Err(TypeError::ParamCountMismatch);
            }
            for (a, b) in p1.iter().zip(p2.iter()) {
                self.unify(a, b)?;
            }
            self.unify(r1, r2)
        }

        // Mismatch
        _ => Err(TypeError::Mismatch { expected: t1, found: t2 }),
    }
}
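To try unification in isolation, the following is a minimal, self-contained sketch. The reduced `Type` enum, the `Infer` struct, and the `HashMap`-backed substitution are assumptions for illustration and may differ from Ori's actual definitions. It also handles a variable unified with itself, which would otherwise trip the occurs check.

```rust
use std::collections::HashMap;

// Hypothetical, reduced versions of the checker's types.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct TypeVarId(u32);

#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    Str,
    TypeVar(TypeVarId),
    List(Box<Type>),
}

#[derive(Debug, PartialEq)]
enum TypeError {
    InfiniteType,
    Mismatch { expected: Type, found: Type },
}

#[derive(Default)]
struct Infer {
    substitution: HashMap<TypeVarId, Type>,
}

impl Infer {
    // Replace every solved variable in `ty` with its binding.
    fn apply_subst(&self, ty: &Type) -> Type {
        match ty {
            Type::TypeVar(id) => match self.substitution.get(id) {
                Some(resolved) => self.apply_subst(resolved),
                None => ty.clone(),
            },
            Type::List(elem) => Type::List(Box::new(self.apply_subst(elem))),
            _ => ty.clone(),
        }
    }

    // `ty` is already fully resolved when this is called from `unify`,
    // so a plain structural walk is sufficient in this sketch.
    fn occurs_in(var: TypeVarId, ty: &Type) -> bool {
        match ty {
            Type::TypeVar(id) => *id == var,
            Type::List(elem) => Self::occurs_in(var, elem),
            _ => false,
        }
    }

    fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
        let t1 = self.apply_subst(t1);
        let t2 = self.apply_subst(t2);
        match (&t1, &t2) {
            (Type::Int, Type::Int) | (Type::Str, Type::Str) => Ok(()),
            // A variable unified with itself needs no binding.
            (Type::TypeVar(a), Type::TypeVar(b)) if a == b => Ok(()),
            (Type::TypeVar(id), ty) | (ty, Type::TypeVar(id)) => {
                if Self::occurs_in(*id, ty) {
                    Err(TypeError::InfiniteType)
                } else {
                    self.substitution.insert(*id, ty.clone());
                    Ok(())
                }
            }
            (Type::List(a), Type::List(b)) => self.unify(a, b),
            _ => Err(TypeError::Mismatch {
                expected: t1.clone(),
                found: t2.clone(),
            }),
        }
    }
}
```

Unifying [T0] with [Int] binds T0 to Int; a later attempt to unify T0 with Str then fails with a mismatch between Int and Str, because both sides are resolved through the substitution before they are compared.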
4. Substitution Application
Apply the substitution to resolve type variables:
fn apply_subst(&self, ty: &Type) -> Type {
    match ty {
        Type::TypeVar(id) => {
            if let Some(resolved) = self.substitution.get(id) {
                self.apply_subst(resolved)
            } else {
                ty.clone()
            }
        }
        Type::List(elem) => Type::List(Box::new(self.apply_subst(elem))),
        Type::Function { params, ret } => Type::Function {
            params: params.iter().map(|p| self.apply_subst(p)).collect(),
            ret: Box::new(self.apply_subst(ret)),
        },
        _ => ty.clone(),
    }
}
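Because a binding may itself mention other variables, resolution has to follow chains. The sketch below shows this with a free function over a plain `HashMap`. The `u32` variable ids and the `example_subst` helper are assumptions made for this example only.

```rust
use std::collections::HashMap;

// Hypothetical, reduced type representation; variables are plain u32 ids.
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    TypeVar(u32),
    List(Box<Type>),
}

// Follow bindings until reaching a concrete type or an unbound variable.
fn apply_subst(subst: &HashMap<u32, Type>, ty: &Type) -> Type {
    match ty {
        Type::TypeVar(id) => match subst.get(id) {
            Some(resolved) => apply_subst(subst, resolved),
            None => ty.clone(),
        },
        Type::List(elem) => Type::List(Box::new(apply_subst(subst, elem))),
        Type::Int => Type::Int,
    }
}

// Build the substitution { T0 := [T1], T1 := Int }.
fn example_subst() -> HashMap<u32, Type> {
    let mut subst = HashMap::new();
    subst.insert(0, Type::List(Box::new(Type::TypeVar(1))));
    subst.insert(1, Type::Int);
    subst
}
```

With this substitution, resolving T0 walks through [T1] to reach [Int], while an unbound variable such as T2 is returned unchanged. The recursion terminates only because the occurs check keeps cyclic bindings out of the substitution.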
Inference Examples
Let Binding
let x = 42
let y = x + 1
1. x : T0 (fresh)
2. 42 : Int
3. Unify(T0, Int) -> substitution[T0] = Int
4. y : T1 (fresh)
5. x + 1 : lookup(+, Int, Int) = Int
6. Unify(T1, Int) -> substitution[T1] = Int
Function Application
@double (x: int) -> int = x * 2
double(21)
1. double : (Int) -> Int
2. 21 : Int
3. double(21) : apply (Int) -> Int to (Int) = Int
Generic Function
@identity<T> (x: T) -> T = x
identity(42)
identity("hello")
1. identity : forall T. (T) -> T
2. identity(42):
   - Instantiate: (T0) -> T0
   - Unify(T0, Int)
   - Result: Int
3. identity("hello"):
   - Instantiate: (T1) -> T1
   - Unify(T1, String)
   - Result: String
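The "Instantiate" step in the trace above can be sketched as follows: each quantified variable is replaced by a fresh one, so separate call sites never share type variables. The `instantiate` and `rename` functions and the reduced `Type` enum are hypothetical names for illustration; only `fresh_type_var` and the `Forall { vars, ty }` shape mirror the document's own excerpts.

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct TypeVarId(u32);

// Hypothetical, reduced type representation.
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    TypeVar(TypeVarId),
    Function { params: Vec<Type>, ret: Box<Type> },
    Forall { vars: Vec<TypeVarId>, ty: Box<Type> },
}

struct Infer {
    next_var: TypeVarId,
}

impl Infer {
    fn fresh_type_var(&mut self) -> Type {
        let id = self.next_var;
        self.next_var = TypeVarId(self.next_var.0 + 1);
        Type::TypeVar(id)
    }

    // Replace each quantified variable with a fresh one.
    // Types without a quantifier are returned unchanged.
    fn instantiate(&mut self, scheme: &Type) -> Type {
        match scheme {
            Type::Forall { vars, ty } => {
                let mapping: HashMap<TypeVarId, Type> = vars
                    .iter()
                    .map(|v| (*v, self.fresh_type_var()))
                    .collect();
                rename(ty, &mapping)
            }
            other => other.clone(),
        }
    }
}

// Substitute mapped variables inside `ty`. Quantifiers are assumed to
// appear only at the top level, so nested ones are left untouched.
fn rename(ty: &Type, mapping: &HashMap<TypeVarId, Type>) -> Type {
    match ty {
        Type::TypeVar(id) => mapping.get(id).cloned().unwrap_or_else(|| ty.clone()),
        Type::Function { params, ret } => Type::Function {
            params: params.iter().map(|p| rename(p, mapping)).collect(),
            ret: Box::new(rename(ret, mapping)),
        },
        _ => ty.clone(),
    }
}
```

Instantiating forall T. (T) -> T twice with a counter starting at zero gives (T0) -> T0 and then (T1) -> T1, matching the two call sites in the trace.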
List Inference
let xs = [1, 2, 3]
let ys = map(over: xs, transform: x -> x * 2)
1. [1, 2, 3] : [T0] where each element unifies with T0
   - 1 : Int, Unify(T0, Int)
   - Result: [Int]
2. map:
   - over: [Int]
   - transform: T1 -> T2
   - Unify(T1, Int) from element type
   - x * 2 : Int, so T2 = Int
   - Result: [Int]
Let Polymorphism
Variables bound with let can be polymorphic:
let id = x -> x // forall T. T -> T
let a = id(42) // Int
let b = id("hello") // String
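This is called "let-generalization": type variables that are free in the value's type but not in the surrounding environment are universally quantified, so each use of the name gets its own instance: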
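fn infer_let(&mut self, name: Name, value: ExprId, body: ExprId) -> Type {
    let value_ty = self.infer_expr(value);

    // Generalize: quantify type variables that are still unbound
    let generalized = self.generalize(value_ty);

    // Scope the binding to the body, as the Let arm of infer_expr does
    self.env.push_scope();
    self.env.bind(name, generalized);
    let body_ty = self.infer_expr(body);
    self.env.pop_scope();
    body_ty
}

fn generalize(&self, ty: Type) -> Type {
    // Resolve solved variables first so only unsolved ones are considered
    let ty = self.apply_subst(&ty);

    // Find type variables not bound in environment
    let free_vars = self.free_type_vars(&ty);
    let env_vars = self.env.free_type_vars();

    // `difference` returns a lazy iterator with no `is_empty`, so peek instead
    let mut generalizable = free_vars.difference(&env_vars).copied().peekable();

    if generalizable.peek().is_none() {
        ty
    } else {
        Type::Forall { vars: generalizable.collect(), ty: Box::new(ty) }
    }
}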
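The generalize step relies on a free_type_vars helper that the excerpt does not show. A minimal sketch of how it might work is given below. The free-function form, the `BTreeSet` (chosen here only for deterministic ordering), and passing the environment's variables in as `env_vars` are all assumptions for illustration.

```rust
use std::collections::BTreeSet;

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct TypeVarId(u32);

// Hypothetical, reduced type representation.
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    TypeVar(TypeVarId),
    Function { params: Vec<Type>, ret: Box<Type> },
    Forall { vars: Vec<TypeVarId>, ty: Box<Type> },
}

// Collect the variables that occur free in `ty` into `out`.
fn free_type_vars(ty: &Type, out: &mut BTreeSet<TypeVarId>) {
    match ty {
        Type::Int => {}
        Type::TypeVar(id) => {
            out.insert(*id);
        }
        Type::Function { params, ret } => {
            for p in params {
                free_type_vars(p, out);
            }
            free_type_vars(ret, out);
        }
        Type::Forall { vars, ty } => {
            // Variables bound by the quantifier are not free.
            let mut inner = BTreeSet::new();
            free_type_vars(ty, &mut inner);
            out.extend(inner.into_iter().filter(|v| !vars.contains(v)));
        }
    }
}

// Quantify every variable that is free in `ty` but not in the environment.
fn generalize(ty: Type, env_vars: &BTreeSet<TypeVarId>) -> Type {
    let mut free = BTreeSet::new();
    free_type_vars(&ty, &mut free);
    let vars: Vec<TypeVarId> = free.difference(env_vars).copied().collect();
    if vars.is_empty() {
        ty
    } else {
        Type::Forall { vars, ty: Box::new(ty) }
    }
}
```

For (T0) -> T1 with T1 already free in the environment, only T0 is quantified. Excluding environment variables matters because they may still be constrained by code outside the let binding, so quantifying them would be unsound.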
Occurs Check
Prevent infinite types like T = [T]:
fn occurs_in(&self, var: TypeVarId, ty: &Type) -> bool {
    match ty {
        Type::TypeVar(id) => {
            if *id == var {
                true
            } else if let Some(resolved) = self.substitution.get(id) {
                self.occurs_in(var, resolved)
            } else {
                false
            }
        }
        Type::List(elem) => self.occurs_in(var, elem),
        Type::Function { params, ret } => {
            params.iter().any(|p| self.occurs_in(var, p))
                || self.occurs_in(var, ret)
        }
        _ => false,
    }
}
Error Reporting
When unification fails, report the mismatch:
Err(TypeError::Mismatch {
    expected: Type::Int,
    found: Type::String,
    span: expr.span,
    context: "in binary addition",
})
Output:
error[E2001]: type mismatch
 --> src/main.si:5:10
  |
5 |     42 + "hello"
  |          ^^^^^^^ expected int, found str
  |
  = note: in binary addition
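As a rough sketch of how such output could be produced from a span, the function below computes the gutter width and caret position. `MismatchReport`, its fields, and `render` are hypothetical names introduced for this example. Ori's real diagnostic types are not shown in this document and may be structured differently.

```rust
// Hypothetical diagnostic record; the field names are illustrative only.
struct MismatchReport<'a> {
    file: &'a str,
    line: usize,
    column: usize,        // 1-based column where the underline starts
    underline_len: usize, // number of characters to underline
    source_line: &'a str,
    expected: &'a str,
    found: &'a str,
    context: &'a str,
}

fn render(r: &MismatchReport<'_>) -> String {
    // The gutter is as wide as the line number so the `|` bars line up.
    let gutter = " ".repeat(r.line.to_string().len());
    // Padding that moves the carets under the offending column.
    let pad = " ".repeat(r.column.saturating_sub(1));
    let carets = "^".repeat(r.underline_len);
    [
        "error[E2001]: type mismatch".to_string(),
        format!("{}--> {}:{}:{}", gutter, r.file, r.line, r.column),
        format!("{} |", gutter),
        format!("{} | {}", r.line, r.source_line),
        format!(
            "{} | {}{} expected {}, found {}",
            gutter, pad, carets, r.expected, r.found
        ),
        format!("{} |", gutter),
        format!("{} = note: {}", gutter, r.context),
    ]
    .join("\n")
}
```

The column and underline length would normally be derived from the expression's span. Here they are passed in directly to keep the sketch small.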