AVL trees implemented using just the join operation

Authors: Jean-Christophe Filliâtre / Paul Patault

Topics: Data Structures / Historical examples / Trees

Just Join for Parallel Ordered Sets

Guy E. Blelloch, Daniel Ferizovic, Yihan Sun 28th ACM Symposium on Parallelism in Algorithms and Architectures, 2016

In the paper above, the authors implement various kinds of balanced binary search trees on top of a join operation. This includes the case of AVL trees, for which the authors prove that join preserves the AVL property (Lemma 1 in the paper).

In the proof below, we verify this lemma using Why3 (the AVL property, not the complexity). The paper skips the details regarding the AVL property---“The resulting tree satisfies the AVL invariants since rotations are used to restore the invariant (details left out)”---but the proof happens to be subtle. See CRITICAL below.

Authors: Jean-Christophe Filliâtre (CNRS) Paul Patault (Univ Paris-Saclay)

use int.Int
use int.MinMax

(* Abstract type of the elements stored in the trees. *)
type elt

the type elt of elements, ordered with lt

(* Program-level equality test on elements: the Boolean it returns is true
   exactly when the two elements are logically equal. *)
val (=) (x y: elt) : bool
  ensures { result <-> x=y }

(* [lt] is the ordering on elements. It is declared with [val predicate], and
   is used both in specifications and in program code. *)
val predicate lt elt elt
(* Instantiate the theory relations.TotalStrictOrder with [elt] and [lt].
   NOTE(review): [axiom .] is understood to keep the cloned axioms as axioms,
   i.e. the order properties are assumed for [lt] rather than proved --
   confirm against the Why3 manual. *)
clone relations.TotalStrictOrder with
  type t = elt, predicate rel = lt, axiom .

(* [E] is the empty tree; [N h l x r] is a node carrying a stored height [h],
   a left subtree [l], an element [x] and a right subtree [r]. *)
type tree = E | N int tree elt tree

the type of AVL trees, with the height stored in the first component so that we get the height in O(1) with function ht

(** [ht t] reads the height stored at the root of [t] (0 for the empty tree)
    without traversing the tree. *)
let function ht (t: tree) : int =
  match t with
  | E         -> 0
  | N h _ _ _ -> h
  end

(** [node l x r] builds a node with subtrees [l] and [r] and element [x],
    storing as its height one more than the larger of the stored heights
    of [l] and [r]. *)
let function node (l: tree) (x: elt) (r: tree) : tree =
  let h = 1 + max (ht l) (ht r) in
  N h l x r

(** [height t] is the actual height of [t], recomputed from its shape and
    ignoring the heights stored in the nodes. It is ghost: used only in
    specifications. The result is never negative. *)
let rec ghost function height (t: tree) : int
  ensures { result >= 0 }
= match t with
  | E         -> 0
  | N _ l _ r -> 1 + max (height l) (height r)
  end

(** [wf t] holds when, at every node of [t], the stored height equals the
    actual height of the subtree rooted at that node. *)
predicate wf (t: tree) =
  match t with
  | E         -> true
  | N h l _x r -> h = height t && wf l && wf r
  end

trees are well-formed i.e. the height stored in the nodes is correct

AVL trees are binary search trees

(** [mem y t] holds when [y] occurs somewhere in [t]. This is a plain
    structural search that does not rely on the ordering of [t]. *)
predicate mem (y: elt) (t: tree) =
  match t with
  | E         -> false
  | N _ l x r -> mem y l || y=x || mem y r
  end

(** [tree_lt t y] holds when every element of [t] is below [y] for [lt]. *)
predicate tree_lt (t: tree) (y: elt) =
  forall z: elt. mem z t -> lt z y

(** [lt_tree y t] holds when [y] is below every element of [t] for [lt]. *)
predicate lt_tree (y: elt) (t: tree) =
  forall z: elt. mem z t -> lt y z

(** [bst t] holds when [t] is a binary search tree: at every node, all
    elements of the left subtree are below the node's element and all
    elements of the right subtree are above it. *)
predicate bst (t: tree) =
  match t with
  | E         -> true
  | N _ l x r -> bst l && tree_lt l x && bst r && lt_tree x r
  end

(** [avl t] is the AVL balance invariant: at every node, the actual heights
    of the two subtrees differ by at most one. *)
predicate avl (t: tree) =
  match t with
  | E         -> true
  | N _ l _ r -> avl l && avl r && -1 <= height l - height r <= 1
  end

AVL height invariant

Code starts here

Note: It is a pity that the specification for rotate_left and rotate_right is longer than the code, but we can't make them logical functions since they are partial functions.

(** [rotate_left t] performs a left rotation at the root of [t]:
    [N _ a x (N _ b y c)] becomes [node (node a x b) y c].
    Requires [t] to be well-formed, to be a binary search tree, and to have
    a non-empty right subtree. The result is well-formed, is a binary search
    tree, and has exactly the shape described above. *)
let rotate_left (t: tree) : (r: tree)
  requires { wf t  } ensures { wf r  }
  requires { bst t } ensures { bst r }
  requires { match t with N _ _ _ (N _ _ _ _) -> true | _ -> false end }
  ensures  { match t with N _ a x (N _ b y c) ->
             match r with N _ (N _ ra rx rb) ry rc ->
               ra=a && rx=x && rb=b && ry=y && rc=c
             | _ -> false end | _ -> false end }
= match t with
  | N _ a x (N _ b y c) -> node (node a x b) y c
  | _ -> absurd (* excluded by the shape precondition *)
  end

(** [rotate_right t] performs a right rotation at the root of [t]:
    [N _ (N _ a x b) y c] becomes [node a x (node b y c)].
    Requires [t] to be well-formed, to be a binary search tree, and to have
    a non-empty left subtree. The result is well-formed, is a binary search
    tree, and has exactly the shape described above. *)
let rotate_right (t: tree) : (r: tree)
  requires { wf t  } ensures { wf r  }
  requires { bst t } ensures { bst r }
  requires { match t with N _ (N _ _ _ _) _ _ -> true | _ -> false end }
  ensures  { match t with N _ (N _ a x b) y c ->
             match r with N _ ra rx (N _ rb ry rc) ->
               ra=a && rx=x && rb=b && ry=y && rc=c
             | _ -> false end | _ -> false end }
= match t with
  | N _ (N _ a x b) y c -> node a x (node b y c)
  | _ -> absurd (* excluded by the shape precondition *)
  end

(** [join_right l x r] joins [l], [x] and [r] when [l] is at least two levels
    taller than [r]. It recurses into the right subtree of [l] until a
    subtree of height at most [ht r + 1] is reached, attaches [x] and [r]
    there, and rebalances with rotations on the way back.
    Requires both trees to be well-formed AVL binary search trees, with all
    elements of [l] below [x] and all elements of [r] above [x].
    The result is a well-formed AVL binary search tree containing exactly
    the elements of [l], [x] and the elements of [r]. Its height is either
    [height l], or [height l + 1], in which case its left subtree has height
    [height l - 1] and its right subtree has height [height l]. *)
let rec join_right (l: tree) (x: elt) (r: tree) : tree
  requires { wf l && wf r } ensures { wf result }
  requires { bst l && tree_lt l x }
  requires { bst r && lt_tree x r } ensures { bst result }
  ensures  { forall y. mem y result <-> (mem y l || y=x || mem y r) }
  requires { height l >= height r + 2 } variant { height l }
  requires { avl l && avl r } ensures { avl result }
  (* CRITICAL *)
  ensures  { height result = height l ||
             height result = height l + 1 && match result with
               | N _ rl _ rr ->
                   height rl = height l - 1 && height rr = height l
               | _ -> false end }
= match l with
  | N _ ll lx lr ->
      if ht lr <= ht r + 1 then
        (* [lr] is low enough to be paired directly with [r] *)
        let t = node lr x r in
        if ht t <= ht ll + 1 then node ll lx t
        else rotate_left (node ll lx (rotate_right t))
      else
        (* [lr] is still too tall: join deeper, then rebalance if needed *)
        let t = join_right lr x r in
        let t' = node ll lx t in
        if ht t <= ht ll + 1 then t' else rotate_left t'
        (*                                ^^^^^^^^^^^^^^
           The CRITICAL postcondition is used here
           to show that the rotated tree is indeed an AVL. *)
  | E -> absurd (* [height l >= height r + 2] rules out the empty tree *)
  end

(** [join_left l x r] is the mirror image of [join_right]: it joins [l], [x]
    and [r] when [r] is at least two levels taller than [l]. It recurses
    into the left subtree of [r] until a subtree of height at most
    [ht l + 1] is reached, attaches [l] and [x] there, and rebalances with
    rotations on the way back.
    Requires both trees to be well-formed AVL binary search trees, with all
    elements of [l] below [x] and all elements of [r] above [x].
    The result is a well-formed AVL binary search tree containing exactly
    the elements of [l], [x] and the elements of [r]. Its height is either
    [height r], or [height r + 1], in which case its right subtree has
    height [height r - 1] and its left subtree has height [height r]. *)
let rec join_left (l: tree) (x: elt) (r: tree) : tree
  requires { wf l && wf r } ensures { wf result }
  requires { bst l && tree_lt l x }
  requires { bst r && lt_tree x r } ensures { bst result }
  ensures  { forall y. mem y result <-> (mem y l || y=x || mem y r) }
  requires { height r >= height l + 2 } variant { height r }
  requires { avl l && avl r } ensures { avl result }
  (* CRITICAL *)
  ensures  { height result = height r ||
             height result = height r + 1 && match result with
               | N _ rl _ rr ->
                   height rr = height r - 1 && height rl = height r
               | _ -> false end }
= match r with
  | N _ rl rx rr ->
      if ht rl <= ht l + 1 then
        (* [rl] is low enough to be paired directly with [l] *)
        let t = node l x rl in
        if ht t <= ht rr + 1 then node t rx rr
        else rotate_right (node (rotate_left t) rx rr)
      else
        (* [rl] is still too tall: join deeper, then rebalance if needed *)
        let t = join_left l x rl in
        let t' = node t rx rr in
        if ht t <= ht rr + 1 then t' else rotate_right t'
        (*                                ^^^^^^^^^^^^^^^ *)
  | E -> absurd (* [height r >= height l + 2] rules out the empty tree *)
  end

(** [join l x r] combines two well-formed AVL binary search trees [l] and [r]
    with a pivot [x] lying above all of [l] and below all of [r].
    The result is a well-formed AVL binary search tree containing exactly
    the elements of [l], [x] and the elements of [r]; its height is at most
    one more than the taller input. *)
let join (l: tree) (x: elt) (r: tree) : tree
  requires { wf l && wf r } ensures { wf result }
  requires { bst l && tree_lt l x }
  requires { bst r && lt_tree x r } ensures { bst result }
  ensures  { forall y. mem y result <-> (mem y l || y=x || mem y r) }
  requires { avl l && avl r } ensures { avl result }
  ensures  { height result <= 1 + max (height l) (height r) }
= let hl = ht l in
  let hr = ht r in
  if hl > hr + 1 then
    (* left tree taller by two or more *)
    join_right l x r
  else if hr > hl + 1 then
    (* right tree taller by two or more *)
    join_left l x r
  else
    (* heights differ by at most one: a plain node is already balanced *)
    node l x r

The remainder is much simpler.

(** [split t y] cuts [t] around [y] and returns [(l, b, r)]:
    [l] holds the elements of [t] below [y], [r] those above [y], and [b]
    is [true] when [y] was found in [t] and [false] when the search reached
    an empty tree.
    Requires [t] to be a well-formed AVL binary search tree; both returned
    trees are well-formed AVL binary search trees as well. *)
let rec split (t: tree) (y: elt) : (l: tree, b: bool, r: tree)
  requires { wf t && bst t && avl t }
  variant  { height t }
  ensures  { wf l && bst l && avl l } ensures { tree_lt l y }
  ensures  { wf r && bst r && avl r } ensures { lt_tree y r }
  ensures  { forall x. mem x t <-> (mem x l || mem x r || b && x=y) }
= match t with
  | E -> E, false, E
  | N _ l x r ->
           if y = x then l, true, r
      else if lt y x then let ll, b, lr = split l y in ll, b, join lr x r
      else                let rl, b, rr = split r y in join l x rl, b, rr
  end

(** [insert x t] returns a well-formed AVL binary search tree whose elements
    are those of [t] together with [x]. It splits [t] around [x] and joins
    the two halves back with [x] as the pivot. *)
let insert (x: elt) (t: tree) : (r: tree)
  requires { wf t && bst t && avl t }
  ensures  { wf r && bst r && avl r }
  ensures  { forall y. mem y r <-> (mem y t || y=x) }
= let below, _, above = split t x in
  join below x above

(** [split_last t] removes the rightmost element of a non-empty tree [t] and
    returns [(r, m)], where [m] is that element and [r] is the rest.
    Requires [t] to be a non-empty, well-formed AVL binary search tree.
    [r] is a well-formed AVL binary search tree, all of its elements are
    below [m], and the elements of [t] are exactly those of [r] plus [m]. *)
let rec split_last (t: tree) : (r: tree, m: elt)
  requires { t <> E }
  requires { wf t && bst t && avl t }
  variant  { height t }
  ensures  { wf r && bst r && avl r }
  ensures  { forall x. mem x t <-> (mem x r && lt x m || x=m) }
  ensures  { tree_lt r m }
= match t with
  | N _ l x E -> l, x (* no right subtree: the root element is the last one *)
  | N _ l x r -> let r', m = split_last r in join l x r', m
  | _ -> absurd (* [t <> E] rules out the empty tree *)
  end

(** [join2 l r] concatenates two well-formed AVL binary search trees where
    every element of [l] is below every element of [r], without a separate
    pivot. When [l] is empty the result is [r]; otherwise the last element
    of [l] is extracted and used as the pivot of a [join].
    The result is a well-formed AVL binary search tree containing exactly
    the elements of [l] and [r]. *)
let join2 (l r: tree) : (t: tree)
  requires { wf l && bst l && avl l }
  requires { wf r && bst r && avl r }
  requires { forall x y. mem x l -> mem y r -> lt x y }
  ensures  { wf t && bst t && avl t }
  ensures  { forall x. mem x t <-> (mem x l || mem x r) }
= match l with
  | E -> r
  | _ -> let l, k = split_last l in join l k r
  end

(** [delete x t] returns a well-formed AVL binary search tree whose elements
    are those of [t] other than [x]. It splits [t] around [x], drops the
    membership flag, and concatenates the two halves. *)
let delete (x: elt) (t: tree) : (r: tree)
  requires { wf t && bst t && avl t }
  ensures  { wf r && bst r && avl r }
  ensures  { forall y. mem y r <-> (mem y t && y<>x) }
= let below, _, above = split t x in
  join2 below above

Why3 Proof Results for Project "just_join"

Theory "just_join.Top": fully verified

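| Obligations | Alt-Ergo 2.4.3 | CVC4 1.8 | Z3 4.12.2 | Z3 4.8.10 |
|---|---|---|---|---|
| VC for height | --- | --- | --- | 0.03 |
| VC for rotate_left | --- | --- | 0.07 | --- |
| VC for rotate_right | --- | --- | 0.02 | --- |
| VC for join_right | --- | --- | --- | --- |
| variant decrease | --- | --- | --- | 0.03 |
| unreachable point | --- | 0.02 | --- | --- |
| VC for join_left | --- | --- | --- | --- |
| variant decrease | --- | --- | --- | 0.02 |
| unreachable point | 0.01 | --- | --- | --- |
| VC for join | --- | --- | --- | 0.13 |
| VC for split | --- | --- | --- | 0.14 |
| VC for insert | --- | --- | --- | 0.06 |
| VC for split_last | --- | --- | --- | 0.31 |
| VC for join2 | --- | --- | --- | 0.22 |
| VC for delete | --- | --- | --- | 0.05 |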