Just Join

AVL trees implemented using just the join operation

Just Join for Parallel Ordered Sets

Guy E. Blelloch, Daniel Ferizovic, and Yihan Sun. 28th ACM Symposium on Parallelism in Algorithms and Architectures, 2016. https://www.cs.cmu.edu/~guyb/papers/BFS16.pdf

In the paper above, the authors implement various kinds of balanced binary search trees on top of a `join` operation. This includes the case of AVL trees, for which the authors prove that `join` preserves the AVL property (Lemma 1 in the paper).

In the proof below, we verify this lemma using Why3 (the AVL property only, not the complexity). The paper skips the details regarding the AVL property---“The resulting tree satisfies the AVL invariants since rotations are used to restore the invariant (details left out)”---but the proof happens to be subtle. See the postconditions marked CRITICAL in `join_right` and `join_left` below.

Authors: Jean-Christophe Filliâtre (CNRS) and Paul Patault (Univ Paris-Saclay)

```use int.Int
use int.MinMax

(* Abstract type of the elements stored in the trees; no definition is
   given, so the development is generic in the element type. *)
type elt
```

the type `elt` of elements, ordered with `lt`

```val (=) (x y: elt) : bool
(* Program-level equality test on elements, declared without code; its
   contract states that the Boolean result coincides with logical equality. *)
ensures { result <-> x=y }

(* Ordering on elements, usable both in programs and in specifications
   (`val predicate`), declared without a definition. *)
val predicate lt elt elt
(* Instantiate the theory of total strict orders with `elt` and `lt`.
   `axiom .` keeps the axioms of the cloned theory as axioms, i.e. `lt` is
   assumed (not proved) to be a total strict order. *)
clone relations.TotalStrictOrder with
type t = elt, predicate rel = lt, axiom .

(* Binary trees: `E` is the empty tree; `N h l x r` is a node carrying a
   stored height `h`, a left subtree `l`, an element `x` and a right
   subtree `r`. Nothing in the type itself forces `h` to be correct; that is
   the role of the `wf` predicate. *)
type tree = E | N int tree elt tree
```

the type of AVL trees, with the height stored in the first component so that we get the height in O(1) with function `ht`

```let function ht (t: tree) : int =
  (* Stored height: read the integer cached in the root, without
     traversing the tree. The empty tree has height 0. *)
  match t with
  | E -> 0
  | N cached _ _ _ -> cached
  end

let function node (l: tree) (x: elt) (r: tree) : tree =
  (* Smart constructor: the height stored in the new node is computed from
     the heights stored in the two children. *)
  let h = 1 + max (ht l) (ht r) in
  N h l x r

let rec ghost function height (t: tree) : int
  ensures { result >= 0 }
= (* Actual height, recomputed structurally; the integer stored in each
     node is ignored. Ghost: only meant for specifications and proofs. *)
  match t with
  | E -> 0
  | N _ left _ right -> 1 + max (height left) (height right)
  end

predicate wf (t: tree) =
  (* Well-formedness: in every node, the stored height `h` equals the actual
     height of the subtree rooted at that node. The element plays no role
     here, hence the wildcard (as in `avl`). *)
  match t with
  | E         -> true
  | N h l _ r -> h = height t && wf l && wf r
  end
```

Trees are well-formed, i.e., the height stored in each node is correct.

AVL trees are binary search trees.

```predicate mem (y: elt) (t: tree) =
  (* `y` occurs somewhere in `t`. This is plain structural membership: the
     definition does not rely on the tree being ordered. *)
  match t with
  | E -> false
  | N _ left root right -> mem y left || y=root || mem y right
  end

predicate tree_lt (t: tree) (y: elt) =
  (* every element of `t` is below `y` for the order `lt` *)
  forall e. mem e t -> lt e y

predicate lt_tree (y: elt) (t: tree) =
  (* `y` is below every element of `t` for the order `lt` *)
  forall e. mem e t -> lt y e

predicate bst (t: tree) =
  (* Search-tree ordering, checked at every node: all elements of the left
     subtree are below the root, which is below all elements of the right
     subtree. *)
  match t with
  | E -> true
  | N _ left root right ->
      bst left && tree_lt left root && bst right && lt_tree root right
  end

predicate avl (t: tree) =
  (* Balance condition, checked at every node: the actual heights (`height`,
     not the stored ones) of the two subtrees differ by at most one. *)
  match t with
  | E -> true
  | N _ left _ right ->
      avl left && avl right && -1 <= height left - height right <= 1
  end
```

AVL height invariant

Code starts here

Note: It is a pity that the specifications of `rotate_left` and `rotate_right` are longer than their code, but we cannot turn them into logical functions since they are partial functions.

```let rotate_left (t: tree) : (r: tree)
(* Left rotation at the root: N _ a x (N _ b y c) is turned into
   N _ (N _ a x b) y c, the two stored heights being recomputed by `node`.
   The contract preserves `wf` and `bst` and pins down the exact shape of
   the result; it says nothing about `avl`, which callers have to establish
   themselves. *)
requires { wf t  } ensures { wf r  }
requires { bst t } ensures { bst r }
(* partiality: the root must have a non-empty right child *)
requires { match t with N _ _ _ (N _ _ _ _) -> true | _ -> false end }
(* shape of the result: same subtrees a, b, c and same elements x, y,
   re-associated to the left *)
ensures  { match t with N _ a x (N _ b y c) ->
match r with N _ (N _ ra rx rb) ry rc ->
ra=a && rx=x && rb=b && ry=y && rc=c
| _ -> false end | _ -> false end }
= match t with
| N _ a x (N _ b y c) -> node (node a x b) y c
(* excluded by the shape precondition *)
| _ -> absurd
end

(* Right rotation at the root: N _ (N _ a x b) y c is turned into
   N _ a x (N _ b y c), the two stored heights being recomputed by `node`.
   The contract preserves `wf` and `bst` and pins down the exact shape of
   the result; it says nothing about `avl`, which callers have to establish
   themselves. *)
let rotate_right (t: tree) : (r: tree)
requires { wf t  } ensures { wf r  }
requires { bst t } ensures { bst r }
(* partiality: the root must have a non-empty left child *)
requires { match t with N _ (N _ _ _ _) _ _ -> true | _ -> false end }
(* shape of the result: same subtrees a, b, c and same elements x, y,
   re-associated to the right *)
ensures  { match t with N _ (N _ a x b) y c ->
match r with N _ ra rx (N _ rb ry rc) ->
ra=a && rx=x && rb=b && ry=y && rc=c
| _ -> false end | _ -> false end }
= match t with
| N _ (N _ a x b) y c -> node a x (node b y c)
(* excluded by the shape precondition *)
| _ -> absurd
end

(* [join_right l x r] joins `l`, `x` and `r` when `l` is taller than `r` by
   at least two levels. It recurses into the right subtree of `l` and
   rebalances with rotations on the way back. The result contains exactly
   the elements of `l`, `x` and the elements of `r`, and is a well-formed
   AVL search tree. *)
let rec join_right (l: tree) (x: elt) (r: tree) : tree
requires { wf l && wf r } ensures { wf result }
requires { bst l && tree_lt l x }
requires { bst r && lt_tree x r } ensures { bst result }
ensures  { forall y. mem y result <-> (mem y l || y=x || mem y r) }
requires { height l >= height r + 2 } variant { height l }
requires { avl l && avl r } ensures { avl result }
(* CRITICAL: the height is either unchanged or grows by exactly one, and in
   the latter case the result leans to the right: its left subtree has
   height `height l - 1` and its right subtree has height `height l`. *)
ensures  { height result = height l ||
height result = height l + 1 && match result with
| N _ rl _ rr ->
height rl = height l - 1 && height rr = height l
| _ -> false end }
= match l with
| N _ ll lx lr ->
(* base case: `lr` is at most one level taller than `r`, so `x` is
   inserted here, between `lr` and `r` *)
if ht lr <= ht r + 1 then
let t = node lr x r in
(* the new right subtree `t` is not too tall for `ll`: no rotation *)
if ht t <= ht ll + 1 then node ll lx t
(* otherwise a double rotation: right on `t`, then left at the root *)
else rotate_left (node ll lx (rotate_right t))
else
(* recursive case: join further down the right subtree of `l`, then
   rotate left at the root if the new right subtree became too tall *)
let t = join_right lr x r in
let t' = node ll lx t in
if ht t <= ht ll + 1 then t' else rotate_left t'
(*                                ^^^^^^^^^^^^^^
The CRITICAL postcondition is used here
to show that the rotated tree is indeed an AVL. *)
(* `l` cannot be empty, since height l >= height r + 2 and heights are
   non-negative *)
| E -> absurd
end

(* [join_left l x r] is the mirror image of `join_right`: it joins `l`, `x`
   and `r` when `r` is taller than `l` by at least two levels. It recurses
   into the left subtree of `r` and rebalances with rotations on the way
   back. The result contains exactly the elements of `l`, `x` and the
   elements of `r`, and is a well-formed AVL search tree. *)
let rec join_left (l: tree) (x: elt) (r: tree) : tree
requires { wf l && wf r } ensures { wf result }
requires { bst l && tree_lt l x }
requires { bst r && lt_tree x r } ensures { bst result }
ensures  { forall y. mem y result <-> (mem y l || y=x || mem y r) }
requires { height r >= height l + 2 } variant { height r }
requires { avl l && avl r } ensures { avl result }
(* CRITICAL: the height is either unchanged or grows by exactly one, and in
   the latter case the result leans to the left: its right subtree has
   height `height r - 1` and its left subtree has height `height r`. *)
ensures  { height result = height r ||
height result = height r + 1 && match result with
| N _ rl _ rr ->
height rr = height r - 1 && height rl = height r
| _ -> false end }
= match r with
| N _ rl rx rr ->
(* base case: `rl` is at most one level taller than `l`, so `x` is
   inserted here, between `l` and `rl` *)
if ht rl <= ht l + 1 then
let t = node l x rl in
(* the new left subtree `t` is not too tall for `rr`: no rotation *)
if ht t <= ht rr + 1 then node t rx rr
(* otherwise a double rotation: left on `t`, then right at the root *)
else rotate_right (node (rotate_left t) rx rr)
else
(* recursive case: join further down the left subtree of `r`, then
   rotate right at the root if the new left subtree became too tall *)
let t = join_left l x rl in
let t' = node t rx rr in
if ht t <= ht rr + 1 then t' else rotate_right t'
(*                                ^^^^^^^^^^^^^^^
The CRITICAL postcondition is used here as well, to show that the
rotated tree is an AVL. *)
(* `r` cannot be empty, since height r >= height l + 2 and heights are
   non-negative *)
| E -> absurd
end

(* [join l x r] builds an AVL search tree containing the elements of `l`,
   then `x`, then the elements of `r`, whatever the height difference
   between `l` and `r`. The result is at most one level taller than the
   taller of the two arguments. *)
let join (l: tree) (x: elt) (r: tree) : tree
  requires { wf l && wf r }
  requires { bst l && tree_lt l x }
  requires { bst r && lt_tree x r }
  requires { avl l && avl r }
  ensures  { wf result }
  ensures  { bst result }
  ensures  { forall y. mem y result <-> (mem y l || y=x || mem y r) }
  ensures  { avl result }
  ensures  { height result <= 1 + max (height l) (height r) }
= (* dispatch on the stored heights: a difference of two or more is handled
     by the one-sided joins, otherwise a plain node is already balanced *)
  let hl = ht l in
  let hr = ht r in
  if hl > hr + 1 then join_right l x r
  else if hr > hl + 1 then join_left l x r
  else node l x r

```

The rest is much simpler.

```let rec split (t: tree) (y: elt) : (l: tree, b: bool, r: tree)
  (* Splits `t` around `y`: `l` receives the elements below `y`, `r` the
     elements above `y`, and `b` tells whether `y` itself occurs in `t`. *)
  requires { wf t && bst t && avl t }
  variant  { height t }
  ensures  { wf l && bst l && avl l }
  ensures  { tree_lt l y }
  ensures  { wf r && bst r && avl r }
  ensures  { lt_tree y r }
  ensures  { forall x. mem x t <-> (mem x l || mem x r || b && x=y) }
= match t with
  | E -> (E, false, E)
  | N _ left pivot right ->
      if y = pivot then (left, true, right)
      else if lt y pivot then
        (* `y` belongs on the left: split there and re-attach the pivot and
           the right subtree to the upper half *)
        let below, found, between = split left y in
        (below, found, join between pivot right)
      else
        (* symmetric case: split the right subtree *)
        let between, found, above = split right y in
        (join left pivot between, found, above)
  end

(* [insert x t] returns an AVL search tree whose elements are those of `t`
   plus `x`. *)
let insert (x: elt) (t: tree) : (r: tree)
  requires { wf t && bst t && avl t }
  ensures  { wf r && bst r && avl r }
  ensures  { forall y. mem y r <-> (mem y t || y=x) }
= (* split around `x` (the membership flag is not needed), then join the
     two halves back with `x` in the middle *)
  let below, _, above = split t x in
  join below x above

(* [split_last t] removes the rightmost element of a non-empty tree: it
   returns the remaining tree `r` together with that element `m`, which is
   above every element of `r`. *)
let rec split_last (t: tree) : (r: tree, m: elt)
  requires { t <> E }
  requires { wf t && bst t && avl t }
  variant  { height t }
  ensures  { wf r && bst r && avl r }
  ensures  { forall x. mem x t <-> (mem x r && lt x m || x=m) }
  ensures  { tree_lt r m }
= match t with
  (* no right subtree: the root is the last element *)
  | N _ left last E -> (left, last)
  (* otherwise the last element is in the right subtree *)
  | N _ left root right ->
      let rest, last = split_last right in
      (join left root rest, last)
  (* excluded by the precondition t <> E *)
  | E -> absurd
  end

(* [join2 l r] concatenates two AVL search trees, every element of `l`
   being below every element of `r`, without a middle element. *)
let join2 (l r: tree) : (t: tree)
  requires { wf l && bst l && avl l }
  requires { wf r && bst r && avl r }
  requires { forall x y. mem x l -> mem y r -> lt x y }
  ensures  { wf t && bst t && avl t }
  ensures  { forall x. mem x t <-> (mem x l || mem x r) }
= match l with
  | E -> r
  | N _ _ _ _ ->
      (* extract the last element of `l` and use it as the middle element
         of a regular join *)
      let init, last = split_last l in
      join init last r
  end

(* [delete x t] returns an AVL search tree whose elements are those of `t`
   minus `x`. *)
let delete (x: elt) (t: tree) : (r: tree)
  requires { wf t && bst t && avl t }
  ensures  { wf r && bst r && avl r }
  ensures  { forall y. mem y r <-> (mem y t && y<>x) }
= (* after splitting around `x`, neither half contains `x`; gluing them
     back together therefore drops it *)
  let below, _, above = split t x in
  join2 below above
```