Wiki Agenda Contact Version française

Maximal sum in a matrix


Authors: Jean-Christophe Filliâtre

Topics: Matrices

Tools: Why3

see also the index (by topic, by tool, by reference, by year)


(* Given an n x n matrix m of nonnegative integers, we want to pick one element
   in each row and each column, so that their sum is maximal.

   We generalize the problem as follows: f(i,c) is the maximum for rows >= i
   and columns in set c. Thus the solution is f(0,{0,1,...,n-1}).

   f is easily defined recursively, as we have

      f(i,c) = max{j in c} (m[i][j] + f(i+1, c\{j}))

   As such, it would still be a brute force approach (of complexity n!)
   but we can memoize f and then the search space decreases to n*2^n.

   The following code implements such a solution. Sets of integers are
   provided in theory Bitset. Hash tables for memoization are provided
   in module HashTable (see file hash_tables.mlw for an implementation).
   Code for f is in module MaxMatrixMemo (mutually recursive functions
   maximum and memo).
*)

theory Bitset

  use int.Int

  constant size : int (* elements belong to 0..size-1 *)

  (* abstract type of sets of integers *)
  type set

  (* membership
     [mem i s] can be implemented as [s land (1 lsl i) <> 0] *)
  val predicate mem int set

  (* removal
     [remove i s] can be implemented as [s land (lnot (1 lsl i))];
     the shorter [s - (1 lsl i)] is only correct when [i] belongs to [s],
     whereas the contract below also covers the removal of an absent element *)
  val function remove (x: int) (s: set): set
    ensures { forall y: int. mem y result <-> y <> x /\ mem y s }

  (* the set {0,1,...,n-1}
     [below n] can be implemented as [1 lsl n - 1] *)
  val function below (n: int): set
    requires { 0 <= n <= size }
    ensures { forall x: int. mem x result <-> 0 <= x < n }

  (* number of elements of a set; only characterized by the axioms below *)
  val function cardinal set: int

  (* a set has cardinal 0 exactly when it has no element *)
  axiom cardinal_empty:
    forall s: set. cardinal s = 0 <-> (forall x: int. not (mem x s))

  (* removing an element that is present decreases the cardinal by one *)
  axiom cardinal_remove:
    forall x: int. forall s: set.
    mem x s -> cardinal s = 1 + cardinal (remove x s)

  (* {0,...,n-1} has n elements; the hypothesis already gives 0 <= n,
     so no case distinction on the sign of n is needed *)
  axiom cardinal_below:
    forall n: int.  0 <= n <= size ->
    cardinal (below n) = n

end

module HashTable

  use option.Option
  use int.Int
  use map.Map

  (* A table is abstract; its ghost field [contents] is a total map sending
     every key to an optional value. *)
  type t 'a 'b = private { ghost mutable contents: map 'a (option 'b) }

  (* [tbl[key]] reads the model of [tbl] at [key]. *)
  function ([]) (tbl: t 'a 'b) (key: 'a) : option 'b =
    Map.get tbl.contents key

  (* [create n] returns a table without any binding; [n] must be positive. *)
  val create (n: int) : t 'a 'b
    requires { 0 < n }
    ensures  { forall key: 'a. result[key] = None }

  (* [clear tbl] removes every binding from [tbl]. *)
  val clear (tbl: t 'a 'b) : unit
    writes  { tbl }
    ensures { forall key: 'a. tbl[key] = None }

  (* [add tbl key v] binds [key] to [v]; all other keys are left unchanged. *)
  val add (tbl: t 'a 'b) (key: 'a) (v: 'b) : unit
    writes  { tbl }
    ensures { tbl[key] = Some v }
    ensures { forall other: 'a. other <> key -> tbl[other] = (old tbl)[other] }

  exception Not_found

  (* [find tbl key] returns the value bound to [key] in [tbl];
     it raises [Not_found] when [key] has no binding. *)
  val find (tbl: t 'a 'b) (key: 'a) : 'b
    ensures { tbl[key] = Some result }
    raises  { Not_found -> tbl[key] = None }

end

module Appmap

  use map.Map
  use map.Const

  (* Type of indices; it is left abstract here and chosen by clients
     when the module is cloned. *)
  type key

  (* An applicative map, modelled by a total map from keys to values. *)
  type t 'a = abstract { contents: map key 'a }

  (* [create d] is the map sending every key to [d]. *)
  val function create (d: 'a): t 'a
    ensures { result.contents = const d }

  (* [a[i]] is the value associated with index [i] in [a]. *)
  val function ([]) (a: t 'a) (i: key): 'a
    ensures { result = Map.get a.contents i }

  (* [a[i <- x]] is a map that agrees with [a] everywhere except at [i],
     where it holds [x]. *)
  val function ([<-]) (a: t 'a) (i: key) (x: 'a): t 'a
    ensures { result.contents = Map.set a.contents i x }

end

module Sum

  use int.Int
  use map.Map

  (* abstract type of containers; instantiated when the module is cloned *)
  type container

  function f container int : int
  (** [f c i] is the [i]-th element in the container [c] *)

  function sum container int int : int
  (** [sum c i j] is the sum \sum_{i <= k < j} f c k *)

  (* the sum over an empty range is zero *)
  axiom Sum_def_empty :
    forall c : container, i j : int. j <= i -> sum c i j = 0

  (* unfolding on the left: peel off the first element *)
  axiom Sum_def_non_empty :
    forall c: container, i j : int. i < j -> sum c i j = f c i + sum c (i+1) j

  (* unfolding on the right: peel off the last element *)
  axiom Sum_right_extension:
    forall c : container, i j : int.
    i < j -> sum c i j = sum c i (j-1) + f c (j-1)

  (* a range can be split at any intermediate index *)
  axiom Sum_transitivity :
    forall c : container, i k j : int. i <= k <= j ->
    sum c i j = sum c i k + sum c k j

  (* containers that agree on the range have the same sum over it *)
  axiom Sum_eq :
    forall c1 c2 : container, i j : int.
    (forall k : int. i <= k < j -> f c1 k = f c2 k) -> sum c1 i j = sum c2 i j

end

(* Memoized computation of the maximal sum obtained by picking one element
   in each row and each column of the n x n matrix [m]. *)
module MaxMatrixMemo

  use int.Int
  use int.MinMax
  use ref.Ref
  use Bitset
  use map.Map

  (* applicative maps indexed by integers; used both for the matrix rows
     and for the solution vectors returned by [maximum] *)
  clone Appmap with type key = int, axiom .

  (* dimension of the matrix; at most [size] so that column sets are bitsets *)
  val constant n : int
    ensures { 0 <= result <= size }

  (* the matrix, as a map of rows; entries with both indices in 0..n-1
     are nonnegative *)
  val constant m : t (t int)
    ensures { forall i j: int. 0 <= i < n -> 0 <= j < n -> 0 <= result[i][j] }

  (* logical view of an assignment: row index -> chosen column *)
  type mapii = Map.map int int

  (* [solution s i]: for rows i..n-1, [s] picks columns in 0..n-1 that are
     pairwise distinct *)
  predicate solution (s: mapii) (i: int) =
    (forall k: int. i <= k < n -> 0 <= Map.get s k < n) /\
    (forall k1 k2: int. i <= k1 < k2 < n -> Map.get s k1 <> Map.get s k2)

  (* a solution covering all rows, starting at row 0 *)
  predicate permutation (s: mapii) = solution s 0

  (* [f s i] is the matrix entry selected by [s] in row [i];
     cloning [Sum] then makes [sum s i j] the sum of [f s k] for i <= k < j *)
  function f (s: mapii) (i: int) : int = m[i][Map.get s i]
  clone Sum with type container = mapii, function f = f, axiom .

  (* setting row [i] of [s] to column [j] contributes m[i][j] and leaves the
     sum over rows i+1..n-1 unchanged *)
  lemma sum_ind:
    forall i: int. i < n -> forall j: int.
    forall s: mapii. sum (Map.set s i j) i n = m[i][j] + sum s (i+1) n

  use option.Option
  use HashTable as H

  (* memo table entries: (first row, available columns) ->
     (maximal sum, assignment reaching it) *)
  type key = (int, set)
  type value = (int, t int)

  (* precondition on a subproblem (i, c): [i] is a row index or n, [c] has
     exactly n-i elements, and all of them are valid column indices *)
  predicate pre (k: key) =
    let (i, c) = k in
    0 <= i <= n /\ cardinal c = n-i /\ (forall k: int. mem k c -> 0 <= k < n)

  (* postcondition on a result (r, sol) for subproblem (i, c):
     [sol] is a solution from row [i] using only columns of [c], [r] is its
     sum, [r] is nonnegative, and no such solution has a larger sum *)
  predicate post (k: key) (v: value) =
    let (i, c) = k in
    let (r, sol) = v in
    0 <= r /\ solution sol.contents i /\
    (forall k: int. i <= k < n -> mem sol[k] c) /\
    r = sum sol.contents i n /\
    (forall s: mapii.
       solution s i -> (forall k: int. i <= k < n -> mem (Map.get s k) c) ->
       r >= sum s i n)

  type table = H.t key value

  (* the global memo table *)
  val table: table

  (* table invariant: every stored entry satisfies [post] for its key *)
  predicate inv (t: table) =
    forall k: key, v: value. H.([]) t k = Some v -> post k v

  (* [maximum i c] computes the subproblem (i, c) without consulting the
     table for (i, c) itself; recursive calls go through [memo].
     The variants 2*n-2*i and 2*n-2*i+1 order the mutual recursion:
     [memo i] may call [maximum i], which may call [memo (i+1)]. *)
  let rec maximum (i:int) (c: set) : (int, t int) variant {2*n-2*i}
    requires { pre (i, c) /\ inv table }
    ensures { post (i,c) result /\ inv table }
  = if i = n then
      (* no row left: the sum is 0 *)
      (0, create 0)
    else begin
      (* best sum found so far and an assignment reaching it;
         -1 means that no column of [c] has been tried yet *)
      let r = ref (-1) in
      let sol = ref (create 0) in
      (* try every column j of [c] for row i; the invariant says that either
         nothing was found among columns below j, or (!r, !sol) is optimal
         among the solutions whose column for row i is below j *)
      for j = 0 to n-1 do
        invariant {
          inv table /\
          (  (!r = -1 /\ forall k: int. 0 <= k < j -> not (mem k c))
          \/
            (0 <= !r /\ solution !sol.contents i /\
              (forall k: int. i <= k < n -> mem !sol[k] c) /\
              !r = sum !sol.contents i n /\
              (forall s: mapii.
                 solution s i -> (forall k: int. i <= k < n -> mem (Map.get s k) c) ->
                 mem (Map.get s i) c -> Map.get s i < j -> !r >= sum s i n)))
        }
        if mem j c then
          (* best completion for the remaining rows without column j *)
          let (r', sol') = memo (i+1) (remove j c) in
          let x = m[i][j] + r' in
          (* strict comparison: on ties the earlier column is kept *)
          if x > !r then begin r := x; sol := sol'[i <- j] end
      done;
      (* at least one column was tried; presumably this follows from
         [cardinal c = n-i > 0] together with [cardinal_empty] *)
      assert { 0 <= !r };
      (!r, !sol)
    end

  (* [memo i c] returns the table entry for (i, c) when there is one;
     otherwise it computes the result with [maximum] and stores it *)
  with memo (i:int) (c: set) : (int, t int) variant {2*n-2*i+1}
    requires { pre (i,c) /\ inv table }
    ensures { post (i,c) result /\ inv table }
  = try  H.find table (i,c)
    with H.Not_found -> let r = maximum i c in H.add table (i,c) r; r end

  (* entry point: clears the table, solves the full problem
     (row 0, columns 0..n-1) and returns the maximal sum only *)
  let maxmat ()
    ensures { exists s: mapii. permutation s /\ result =  sum s 0 n }
    ensures { forall s: mapii. permutation s -> result >= sum s 0 n }
  = H.clear table;
    assert { inv table };
    let (r, _) = maximum 0 (below n) in r

end

download ZIP archive

Why3 Proof Results for Project "max_matrix"

Theory "max_matrix.MaxMatrixMemo": fully verified

Obligations | Alt-Ergo 2.0.0 | Alt-Ergo 2.3.0 | CVC4 1.4 | CVC4 1.5 | Z3 4.8.4
VC for n | --- | 0.00 | --- | --- | ---
VC for m | --- | 0.00 | --- | --- | ---
sum_ind | 0.03 | --- | --- | --- | ---
VC for maximum | --- | --- | --- | --- | ---
split_goal_right
postcondition | 0.04 | --- | --- | --- | ---
loop invariant init | 0.00 | --- | --- | --- | ---
variant decrease | 0.01 | --- | --- | --- | ---
precondition | 0.01 | --- | --- | --- | ---
loop invariant preservation | --- | --- | --- | --- | ---
split_vc
VC for maximum | --- | 0.01 | --- | --- | ---
VC for maximum | --- | --- | --- | --- | ---
remove real,bool,string,tuple0,unit,ref,map,mapii,key,value,zero,one,( * ),min,max,const,(!),is_none,size,below,create,permutation,table1,table2,Assoc,Unit_def_l,Unit_def_r,Inv_def_l,Inv_def_r,Comm,Assoc1,Mul_distr_l,Mul_distr_r,Comm1,Unitary,NonTrivialRing,Refl,Trans,Antisymm,Total,ZeroLessOne,CompatOrderAdd,CompatOrderMult,Min_r,Max_l,Min_comm,Max_comm,Min_assoc,Max_assoc,is_none'spec,below'spec,cardinal_empty,cardinal_remove,cardinal_below,create'spec,n'def,m'def,Sum_def_empty,Sum_right_extension,Sum_transitivity,Sum_eq,H5,H7,H8,Ensures1,H10,Ensures3,Ensures4,Ensures5,Ensures6
VC for maximum | 1.05 | 2.25 | --- | --- | 0.08
loop invariant preservation | --- | --- | --- | --- | ---
split_vc
VC for maximum | --- | 0.01 | --- | --- | ---
VC for maximum | --- | --- | --- | --- | ---
right
right case | --- | --- | --- | --- | ---
split_vc
right case | --- | 0.02 | --- | --- | ---
right case | --- | 0.02 | --- | --- | ---
right case | --- | 0.02 | --- | --- | ---
right case | --- | 0.08 | --- | --- | ---
right case | 1.72 | --- | --- | --- | ---
loop invariant preservation | --- | --- | --- | --- | ---
split_vc
VC for maximum | --- | --- | --- | --- | ---
right
VC for maximum | --- | --- | --- | --- | ---
split_goal_right
VC for maximum | --- | --- | --- | 0.06 | ---
VC for maximum | 0.01 | --- | --- | --- | ---
assertion | --- | --- | 0.03 | --- | ---
postcondition | 0.00 | --- | --- | --- | ---
out of loop bounds | 0.01 | --- | --- | --- | ---
VC for memo | 0.02 | --- | --- | --- | ---
VC for maxmat | 0.12 | --- | --- | --- | ---