Unification
Unification is the process of finding a substitution that makes two types equal. It’s the core algorithm for type inference.
Basic Algorithm
/// Unifies `t1` with `t2`, extending the checker's substitution as needed.
///
/// Both inputs are first resolved through `apply_subst`, so bindings made by
/// earlier calls are taken into account before any comparison happens.
///
/// # Errors
///
/// * `TypeError::TupleLengthMismatch` when two tuples differ in arity.
/// * `TypeError::ParamCountMismatch` when two functions differ in parameter count.
/// * `TypeError::Mismatch` when the two resolved types have no matching rule.
/// * Whatever `bind_var` reports when one side is a type variable.
fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
    // Resolve both sides against the bindings recorded so far.
    let lhs = self.apply_subst(t1);
    let rhs = self.apply_subst(t2);

    // Equal types unify trivially; nothing needs to be recorded.
    if lhs == rhs {
        return Ok(());
    }

    match (&lhs, &rhs) {
        // A type variable on either side is delegated to `bind_var`.
        (Type::TypeVar(id), other) | (other, Type::TypeVar(id)) => self.bind_var(*id, other),

        // Single-argument constructors: unify the payloads.
        (Type::List(a), Type::List(b)) => self.unify(a, b),
        (Type::Option(a), Type::Option(b)) => self.unify(a, b),

        // `Result`: the error sides are only visited once the ok sides unified.
        (Type::Result(ok1, err1), Type::Result(ok2, err2)) => self
            .unify(ok1, ok2)
            .and_then(|()| self.unify(err1, err2)),

        // Tuples: same arity, then element by element, stopping at the first error.
        (Type::Tuple(ts1), Type::Tuple(ts2)) => {
            if ts1.len() == ts2.len() {
                ts1.iter()
                    .zip(ts2.iter())
                    .try_for_each(|(a, b)| self.unify(a, b))
            } else {
                Err(TypeError::TupleLengthMismatch)
            }
        }

        // Functions: parameters pairwise, then the return types. Fields other
        // than `params` and `ret` are not compared here.
        (
            Type::Function { params: p1, ret: r1, .. },
            Type::Function { params: p2, ret: r2, .. },
        ) => {
            if p1.len() == p2.len() {
                p1.iter()
                    .zip(p2.iter())
                    .try_for_each(|(a, b)| self.unify(a, b))?;
                self.unify(r1, r2)
            } else {
                Err(TypeError::ParamCountMismatch)
            }
        }

        // No rule applies: report the resolved types.
        _ => Err(TypeError::Mismatch {
            expected: lhs.clone(),
            found: rhs.clone(),
        }),
    }
}
Variable Binding
When binding a type variable, check for cycles:
/// Binds the type variable `var` to `ty` in the current substitution.
///
/// * If `var` already has a binding, that binding is unified with `ty` instead.
/// * Binding a variable to itself is a no-op.
///
/// # Errors
///
/// Returns `TypeError::InfiniteType` when `var` occurs inside `ty` (the occurs
/// check), and propagates any error from unifying with an existing binding.
fn bind_var(&mut self, var: TypeVarId, ty: &Type) -> Result<(), TypeError> {
    // Clone the existing binding so the shared borrow of `self.substitution`
    // ends before `unify` takes `&mut self`; holding the reference across the
    // call is rejected by the borrow checker.
    if let Some(existing) = self.substitution.get(&var).cloned() {
        return self.unify(&existing, ty);
    }

    // `T = T` carries no information; skip it so the occurs check below
    // does not reject it as an infinite type.
    if let Type::TypeVar(id) = ty {
        if *id == var {
            return Ok(());
        }
    }

    // Occurs check: binding `var` to a type that contains `var` would
    // describe an infinite type.
    if self.occurs_in(var, ty) {
        return Err(TypeError::InfiniteType {
            var,
            ty: ty.clone(),
        });
    }

    // Record the new binding.
    self.substitution.insert(var, ty.clone());
    Ok(())
}
Occurs Check
The occurs check prevents creating infinite types:
// Would create: T0 = [T0] = [[T0]] = ...
let xs = [xs] // Error: infinite type
/// Reports whether the type variable `var` appears anywhere inside `ty`.
///
/// Other type variables found in `ty` are followed through the current
/// substitution, so an occurrence hidden behind an existing binding is
/// detected as well.
fn occurs_in(&self, var: TypeVarId, ty: &Type) -> bool {
    match ty {
        // Either this is `var` itself, or it is another variable whose
        // binding (if any) has to be searched in turn.
        Type::TypeVar(id) => {
            *id == var
                || self
                    .substitution
                    .get(id)
                    .map_or(false, |bound| self.occurs_in(var, bound))
        }

        // Single-argument constructors: look inside the payload.
        Type::List(elem) => self.occurs_in(var, elem),
        Type::Option(inner) => self.occurs_in(var, inner),

        // Multi-argument constructors: any component may contain `var`.
        Type::Result(ok, err) => self.occurs_in(var, ok) || self.occurs_in(var, err),
        Type::Tuple(elems) => elems.iter().any(|elem| self.occurs_in(var, elem)),
        Type::Function { params, ret, .. } => {
            let in_params = params.iter().any(|param| self.occurs_in(var, param));
            in_params || self.occurs_in(var, ret)
        }

        // Every remaining variant is treated as containing no type variables.
        _ => false,
    }
}
Substitution
A substitution maps type variables to types:
/// A substitution: the bindings from type variables to types.
struct Substitution {
    /// Binding for each type variable that has one.
    map: HashMap<TypeVarId, Type>,
}

impl Substitution {
    /// Returns a copy of `ty` with every bound type variable replaced by its
    /// binding.
    ///
    /// Bindings are applied recursively, so a chain such as `T0 -> T1 -> Int`
    /// resolves all the way to `Int`. Unbound variables and types handled by
    /// the catch-all arm are returned as clones.
    fn apply(&self, ty: &Type) -> Type {
        match ty {
            // Follow the binding (recursively) or leave the variable as is.
            Type::TypeVar(id) => match self.map.get(id) {
                Some(bound) => self.apply(bound),
                None => ty.clone(),
            },

            Type::List(elem) => Type::List(Box::new(self.apply(elem))),

            // The same constructors that `unify` and `occurs_in` traverse must
            // be rebuilt here, otherwise variables nested inside them would be
            // left unresolved.
            // NOTE(review): assumes `Option`/`Result` box their payloads like
            // `List`, and that `Tuple` holds a collectable sequence of `Type`
            // - confirm against the `Type` definition.
            Type::Option(inner) => Type::Option(Box::new(self.apply(inner))),
            Type::Result(ok, err) => {
                Type::Result(Box::new(self.apply(ok)), Box::new(self.apply(err)))
            }
            Type::Tuple(elems) => Type::Tuple(elems.iter().map(|e| self.apply(e)).collect()),

            Type::Function { params, ret, caps } => Type::Function {
                params: params.iter().map(|p| self.apply(p)).collect(),
                ret: Box::new(self.apply(ret)),
                // `caps` is carried over unchanged.
                caps: caps.clone(),
            },

            // Remaining variants are returned unchanged.
            _ => ty.clone(),
        }
    }
}
Unification Examples
Simple Unification
unify(Int, Int) = Ok(())
unify(Int, String) = Err(Mismatch)
Variable Unification
unify(T0, Int) = Ok(substitution[T0] = Int)
unify(T0, T1) = Ok(substitution[T0] = T1)
Compound Unification
unify([T0], [Int])
= unify(T0, Int)
= Ok(substitution[T0] = Int)
unify((Int, T0), (Int, String))
= unify(Int, Int) = Ok
= unify(T0, String) = Ok(substitution[T0] = String)
Function Unification
unify((T0) -> T0, (Int) -> Int)
= unify(T0, Int) = Ok
= unify(T0, Int) = Ok (T0 already Int)
= Ok(substitution[T0] = Int)
Failure Cases
// Length mismatch
unify((Int, Int), (Int,)) = Err(TupleLengthMismatch)
// Type mismatch
unify([Int], {String: Int}) = Err(Mismatch)
// Occurs check failure
unify(T0, [T0]) = Err(InfiniteType)
Constraint-Based Approach
Instead of unifying immediately, collect constraints:
/// A deferred "these two types must unify" requirement, recorded by
/// `add_constraint` and resolved later by `solve_constraints`.
struct Constraint {
    /// First type handed to `unify` when the constraint is solved.
    left: Type,
    /// Second type handed to `unify` when the constraint is solved.
    right: Type,
    /// Source span attached (via `with_span`) to any error this constraint produces.
    span: Span,
    /// Value of `current_context()` at the moment the constraint was recorded.
    context: String,
}
impl TypeChecker {
    /// Records that `left` and `right` must unify, without unifying them yet.
    ///
    /// The constraint keeps `span` and the checker's current context so that
    /// a later failure can be reported against the place that caused it.
    fn add_constraint(&mut self, left: Type, right: Type, span: Span) {
        let context = self.current_context();
        self.constraints.push(Constraint {
            left,
            right,
            span,
            context,
        });
    }

    /// Unifies every recorded constraint in insertion order.
    ///
    /// Solving does not stop at the first failure: each failing constraint
    /// contributes one error, tagged with that constraint's span.
    ///
    /// # Errors
    ///
    /// Returns all collected errors if at least one constraint failed to unify.
    fn solve_constraints(&mut self) -> Result<(), Vec<TypeError>> {
        // Move the list out of `self` for the duration of the loop: iterating
        // `&self.constraints` while calling `self.unify` (which needs
        // `&mut self`) is rejected by the borrow checker.
        let constraints = std::mem::take(&mut self.constraints);

        let mut errors = Vec::new();
        for constraint in &constraints {
            if let Err(e) = self.unify(&constraint.left, &constraint.right) {
                errors.push(e.with_span(constraint.span));
            }
        }

        // Put the constraints back so the checker's state is unchanged.
        self.constraints = constraints;

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
Benefits of constraint collection:
- Better error messages (each constraint records its source span and context)
- Can report multiple errors
- Enables advanced inference features
Union-Find Optimization
For efficiency, use union-find for type variables:
/// Disjoint-set forest over the indices `0..parent.len()`.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Upper bound on the height of the tree rooted at each index.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Every node visited on the way is re-pointed directly at the root
    /// (path compression).
    ///
    /// # Panics
    ///
    /// Panics if `x` is out of bounds for `parent`.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Pass 2: re-point every node on the path at the root.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`; does nothing if they are
    /// already the same set.
    ///
    /// The root of lower rank is attached beneath the other (union by rank).
    /// On a tie, `y`'s root goes under `x`'s root, whose rank grows by one.
    fn union(&mut self, x: usize, y: usize) {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return;
        }

        let (rank_x, rank_y) = (self.rank[root_x], self.rank[root_y]);
        let (child, new_root) = if rank_x < rank_y {
            (root_x, root_y)
        } else {
            (root_y, root_x)
        };

        self.parent[child] = new_root;
        if rank_x == rank_y {
            // Only a merge of equal-rank trees can make the result taller.
            self.rank[new_root] += 1;
        }
    }
}
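With path compression and union by rank, `find` runs in amortized O(α(n)) time, where α is the inverse Ackermann function, so variable lookup is effectively constant-time in practice.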