open Num;;
open Format;;
(* ========================================================================= *)
(* Misc library functions to set up a nice environment.                      *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* The identity function: hands its argument back unchanged.                 *)
let identity = fun v -> v;;

(* ------------------------------------------------------------------------- *)
(* Function composition.                                                     *)
(* ------------------------------------------------------------------------- *)

(* [f ** g] is the function that applies g first and then f to the result.   *)
(* Note that this rebinds the operator name ( ** ).                          *)
let ( ** ) f g = fun x -> f (g x);;

(* ------------------------------------------------------------------------- *)
(* GCD and LCM on arbitrary-precision numbers.                               *)
(* ------------------------------------------------------------------------- *)

(* GCD of two nums, computed via Big_int.gcd_big_int; abs_num is applied to  *)
(* the result.                                                               *)
let gcd_num n1 n2 =
  let g = Big_int.gcd_big_int (big_int_of_num n1) (big_int_of_num n2) in
  abs_num (num_of_big_int g);;

(* LCM of two nums: |n1 * n2| divided by their GCD. When the GCD is zero     *)
(* (both arguments zero) the result is zero rather than a division by zero.  *)
let lcm_num n1 n2 =
  let g = gcd_num n1 n2 in
  if g =/ Int 0 then Int 0 else abs_num (n1 */ n2) // g;;

(* ------------------------------------------------------------------------- *)
(* A useful idiom for "non contradictory" etc.                               *)
(* ------------------------------------------------------------------------- *)

(* [non p] is the negation of predicate p.                                   *)
let non p = fun x -> not (p x);;

(* ------------------------------------------------------------------------- *)
(* Repetition of a function.                                                 *)
(* ------------------------------------------------------------------------- *)

(* [funpow n f x] applies f to x n times; for n < 1 it returns x unchanged.  *)
let rec funpow n f x =
  if n >= 1 then funpow (n - 1) f (f x) else x;;

(* [can f x] is true if evaluating [f x] completes, false if it raises       *)
(* Failure. Any other exception propagates. The result of [f x] is           *)
(* explicitly ignored rather than dropped by a bare sequence.                *)
let can f x = try ignore (f x); true with Failure _ -> false;;

(* ------------------------------------------------------------------------- *)
(* Handy list operations.                                                    *)
(* ------------------------------------------------------------------------- *)

(* [lo -- hi] is the list of integers lo, lo+1, ..., hi (empty if lo > hi).  *)
let rec (--) = fun lo hi ->
  if lo <= hi then lo :: ((lo + 1) -- hi) else [];;

(* Same as (--) but over arbitrary-precision nums, stepping by Int 1.        *)
let rec (---) = fun lo hi ->
  if lo <=/ hi then lo :: ((lo +/ Int 1) --- hi) else [];;

(* Map a two-argument function over two lists in step.                       *)
(* Fails with "map2: length mismatch" if the lists differ in length.         *)
let rec map2 f l1 l2 =
  match l1, l2 with
  | [], [] -> []
  | x :: xs, y :: ys ->
      (* Compute the head first so f is applied front to back. *)
      let z = f x y in
      z :: map2 f xs ys
  | [], _ :: _ | _ :: _, [] -> failwith "map2: length mismatch";;

(* List reversal, accumulating the elements front to back.                   *)
let rev l = List.fold_left (fun acc x -> x :: acc) [] l;;

(* First element of a list; fails with "hd" on the empty list.               *)
let hd = function
  | x :: _ -> x
  | [] -> failwith "hd";;

(* All but the first element of a list; fails with "tl" on the empty list.   *)
let tl = function
  | _ :: rest -> rest
  | [] -> failwith "tl";;

(* Right fold: itlist f [x1;...;xn] b = f x1 (f x2 (... (f xn b))).          *)
let rec itlist f l b =
  match l with
  | x :: rest -> f x (itlist f rest b)
  | [] -> b;;

(* Right fold over a non-empty list, using the last element as the seed.     *)
(* Fails with "end_itlist" on the empty list.                                *)
let rec end_itlist f l =
  match l with
  | [only] -> only
  | x :: rest -> f x (end_itlist f rest)
  | [] -> failwith "end_itlist";;

(* Right fold over two lists in step; fails with "itlist2" if their lengths  *)
(* differ.                                                                   *)
let rec itlist2 f l1 l2 b =
  match l1, l2 with
  | [], [] -> b
  | x :: xs, y :: ys -> f x y (itlist2 f xs ys b)
  | [], _ :: _ | _ :: _, [] -> failwith "itlist2";;

(* Pair up corresponding elements of two lists; fails with "zip" if the      *)
(* lengths differ.                                                           *)
let rec zip l1 l2 =
  match l1, l2 with
  | [], [] -> []
  | x :: xs, y :: ys -> (x, y) :: zip xs ys
  | [], _ :: _ | _ :: _, [] -> failwith "zip";;

(* [forall p l] holds when p is true of every element of l (true for []).    *)
(* Evaluation stops at the first element that fails p.                       *)
let rec forall p l =
  match l with
  | [] -> true
  | h :: t -> p h && forall p t;;

(* [exists p l] holds when p is true of some element of l (false for []).    *)
(* Evaluation stops at the first element that satisfies p.                   *)
let rec exists p l =
  match l with
  | [] -> false
  | h :: t -> p h || exists p t;;

(* Split l into (elements satisfying p, elements not satisfying p), keeping  *)
(* the original order within each part.                                      *)
let partition p l =
  let step x (ins, outs) =
    if p x then (x :: ins, outs) else (ins, x :: outs) in
  itlist step l ([], []);;

(* Elements of l that satisfy p: the first component of [partition p l].     *)
let filter p l =
  let (kept, _) = partition p l in
  kept;;

(* Number of elements in a list, counted with an accumulator.                *)
let length =
  let rec count acc = function
    | [] -> acc
    | _ :: rest -> count (acc + 1) rest in
  fun l -> count 0 l;;

(* Final element of a list; fails with "last" on the empty list.             *)
let rec last = function
  | [] -> failwith "last"
  | [x] -> x
  | _ :: rest -> last rest;;

(* Everything except the final element; fails with "butlast" on [].          *)
let rec butlast = function
  | [] -> failwith "butlast"
  | [_] -> []
  | x :: rest -> x :: butlast rest;;

(* First element of the list satisfying p; fails with "find" if none does.   *)
let rec find p = function
  | [] -> failwith "find"
  | x :: rest -> if p x then x else find p rest;;

(* Element at zero-based position n, found by dropping n heads and then      *)
(* taking the head; the hd/tl failures propagate when the list is too short. *)
let rec el n l =
  match n with
  | 0 -> hd l
  | _ -> el (n - 1) (tl l);;

(* Apply f to each element of a list, in order from the front.               *)
let map f =
  let rec go = function
    | [] -> []
    | x :: rest ->
        (* Bind the head result before recursing to fix the call order. *)
        let y = f x in
        y :: go rest in
  go;;

(* All results [f x y] for x in l1 and y in l2, grouped by x in l1 order and *)
(* by y in l2 order within each group.                                       *)
let allpairs f l1 l2 =
  itlist (fun x acc -> map (f x) l2 @ acc) l1 [];;

(* All pairs (a,b) of elements of l for which a < b holds.                   *)
let distinctpairs l =
  let pairs = allpairs (fun a b -> (a, b)) l l in
  filter (fun (a, b) -> a < b) pairs;;

(* Split l into its first n elements and the remainder.                      *)
(* Fails with "chop_list" if the list runs out before n reaches zero.        *)
let rec chop_list n l =
  if n = 0 then ([], l)
  else
    match l with
    | [] -> failwith "chop_list"
    | x :: rest ->
        let front, back = chop_list (n - 1) rest in
        (x :: front, back);;

(* A list holding n copies of a (empty when n < 1).                          *)
let replicate n a = map (fun _ -> a) (1 -- n);;

(* Insert x so that it ends up at zero-based position i of the result.       *)
(* Fails if the list has fewer than i elements.                              *)
let rec insertat i x l =
  match i, l with
  | 0, _ -> x :: l
  | _, [] -> failwith "insertat: list too short for position to exist"
  | _, h :: t -> h :: insertat (i - 1) x t;;

(* [forall2 p l1 l2] holds when the lists have equal length and p is true of *)
(* every pair of corresponding elements. A length mismatch gives false       *)
(* rather than a failure.                                                    *)
let rec forall2 p l1 l2 =
  match l1, l2 with
  | [], [] -> true
  | h1 :: t1, h2 :: t2 -> p h1 h2 && forall2 p t1 t2
  | [], _ :: _ | _ :: _, [] -> false;;

(* Zero-based position of the first element equal to x; fails with "index"   *)
(* if x does not occur.                                                      *)
let index x =
  let rec scan pos = function
    | [] -> failwith "index"
    | h :: t -> if x = h then pos else scan (pos + 1) t in
  scan 0;;

(* Turn a list of pairs into a pair of lists.                                *)
let rec unzip = function
  | [] -> ([], [])
  | (a, b) :: rest ->
      let (firsts, seconds) = unzip rest in
      (a :: firsts, b :: seconds);;

(* ------------------------------------------------------------------------- *)
(* Whether the first of two items comes earlier in the list.                 *)
(* ------------------------------------------------------------------------- *)

(* True iff x is met in l before any occurrence of y. Each element is tested *)
(* against y first, so the result is false when x = y.                       *)
let rec earlier l x y =
  match l with
  | [] -> false
  | h :: _ when h = y -> false
  | h :: _ when h = x -> true
  | _ :: t -> earlier t x y;;

(* ------------------------------------------------------------------------- *)
(* Application of (presumably imperative) function over a list.              *)
(* ------------------------------------------------------------------------- *)

(* Apply f to each element in order, discarding the results.                 *)
let rec do_list f = function
  | [] -> ()
  | x :: rest -> f x; do_list f rest;;

(* ------------------------------------------------------------------------- *)
(* Association lists.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Value paired with the first key equal to x; fails (via find) if absent.   *)
let assoc x l =
  let (_, v) = find (fun (k, _) -> k = x) l in
  v;;

(* Key paired with the first value equal to x; fails (via find) if absent.   *)
let rev_assoc x l =
  let (k, _) = find (fun (_, v) -> v = x) l in
  k;;

(* ------------------------------------------------------------------------- *)
(* Merging of sorted lists (maintaining repetitions).                        *)
(* ------------------------------------------------------------------------- *)

(* Merge two lists under ord: at each step the head of l1 is emitted when    *)
(* [ord h1 h2] holds, otherwise the head of l2. Repetitions are kept.        *)
let rec merge ord l1 l2 =
  match l1, l2 with
  | [], _ -> l2
  | _, [] -> l1
  | h1 :: t1, h2 :: t2 ->
      if ord h1 h2 then h1 :: merge ord t1 l2
      else h2 :: merge ord l1 t2;;

(* ------------------------------------------------------------------------- *)
(* Bottom-up mergesort.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Bottom-up mergesort under ord. The input is turned into singleton runs;   *)
(* each pass merges adjacent runs pairwise until one run remains.            *)
let sort ord =
  (* merged: runs produced so far in this pass; pending: runs still to pair. *)
  let rec mergepairs merged pending =
    match pending, merged with
    | [], [single] -> single
    | [], _ -> mergepairs [] merged
    | [s], _ -> mergepairs (s :: merged) []
    | s1 :: s2 :: rest, _ -> mergepairs (merge ord s1 s2 :: merged) rest in
  fun l ->
    match l with
    | [] -> []
    | _ :: _ -> mergepairs [] (map (fun x -> [x]) l);;

(* ------------------------------------------------------------------------- *)
(* Common measure predicates to use with "sort".                             *)
(* ------------------------------------------------------------------------- *)

(* Ordering predicate for "sort": x precedes y when f x < f y.               *)
let increasing f = fun x y -> f x < f y;;

(* Ordering predicate for "sort": x precedes y when f x > f y.               *)
let decreasing f = fun x y -> f x > f y;;

(* ------------------------------------------------------------------------- *)
(* Eliminate repetitions of adjacent elements, with and without counting.    *)
(* ------------------------------------------------------------------------- *)

(* Collapse each run of adjacent equal elements down to a single element.    *)
let rec uniq l =
  match l with
  | a :: (b :: _ as rest) when a = b -> uniq rest
  | a :: (_ :: _ as rest) -> a :: uniq rest
  | [] | [_] -> l;;

(* Run-length encode adjacent equal elements: each run becomes a pair        *)
(* (element, length of run). The empty list maps to the empty list.          *)
let repetitions =
  (* n is the number of copies of the current head seen so far. *)
  let rec repcount n = function
    | [] -> []
    | [x] -> [(x, n)]
    | x :: (y :: _ as ys) ->
        if y = x then repcount (n + 1) ys
        else (x, n) :: repcount 1 ys in
  fun l -> repcount 1 l;;

(* Result of f on the first element where f does not raise Failure; fails    *)
(* with "tryfind" if f fails on every element.                               *)
let rec tryfind f = function
  | [] -> failwith "tryfind"
  | x :: rest -> (try f x with Failure _ -> tryfind f rest);;

(* Map f over the list, dropping the elements on which f raises Failure.     *)
(* The tail is processed before f is applied to the head.                    *)
let rec mapfilter f = function
  | [] -> []
  | x :: rest ->
      let tail = mapfilter f rest in
      (try f x :: tail with Failure _ -> tail);;

(* ------------------------------------------------------------------------- *)
(* Set operations on ordered lists.                                          *)
(* ------------------------------------------------------------------------- *)

(* Canonical set representation: a strictly increasing list. A list that is  *)
(* already strictly increasing is returned as is; otherwise it is sorted     *)
(* with (<=) and adjacent duplicates are removed.                            *)
let setify =
  let rec canonical = function
    | x :: (y :: _ as rest) -> x < y && canonical rest
    | [] | [_] -> true in
  fun l -> if canonical l then l else uniq (sort (<=) l);;

(* Set union: both arguments are canonicalized with setify, then merged      *)
(* with equal heads emitted once.                                            *)
let union =
  let rec join l1 l2 =
    match l1, l2 with
    | [], rest | rest, [] -> rest
    | h1 :: t1, h2 :: t2 ->
        if h1 = h2 then h1 :: join t1 t2
        else if h1 < h2 then h1 :: join t1 l2
        else h2 :: join l1 t2 in
  fun s1 s2 -> join (setify s1) (setify s2);;

(* Set intersection: both arguments are canonicalized with setify, then      *)
(* walked in step keeping only the elements present in both.                 *)
let intersect =
  let rec common l1 l2 =
    match l1, l2 with
    | [], _ | _, [] -> []
    | h1 :: t1, h2 :: t2 ->
        if h1 = h2 then h1 :: common t1 t2
        else if h1 < h2 then common t1 l2
        else common l1 t2 in
  fun s1 s2 -> common (setify s1) (setify s2);;

(* Set difference s1 \ s2: both arguments are canonicalized with setify,     *)
(* then elements of the first that also occur in the second are dropped.     *)
let subtract =
  let rec diff l1 l2 =
    match l1, l2 with
    | [], _ -> []
    | _, [] -> l1
    | h1 :: t1, h2 :: t2 ->
        if h1 = h2 then diff t1 t2
        else if h1 < h2 then h1 :: diff t1 l2
        else diff l1 t2 in
  fun s1 s2 -> diff (setify s1) (setify s2);;

(* Subset and proper-subset tests. Both arguments are canonicalized with     *)
(* setify before the ordered lists are compared in step.                     *)
let subset,psubset =
  (* incl l1 l2: every element of l1 occurs in l2. *)
  let rec incl l1 l2 =
    match l1, l2 with
    | [], _ -> true
    | _ :: _, [] -> false
    | h1 :: t1, h2 :: t2 ->
        if h1 = h2 then incl t1 t2
        else not (h1 < h2) && incl l1 t2 in
  (* strict l1 l2: as incl, but l2 must also have an element not in l1. *)
  let rec strict l1 l2 =
    match l1, l2 with
    | _, [] -> false
    | [], _ :: _ -> true
    | h1 :: t1, h2 :: t2 ->
        if h1 = h2 then strict t1 t2
        (* h2 is not in l1, so plain inclusion of l1 in the rest suffices. *)
        else not (h1 < h2) && incl l1 t2 in
  (fun s1 s2 -> incl (setify s1) (setify s2)),
  (fun s1 s2 -> strict (setify s1) (setify s2));;

(* Two lists denote the same set iff their canonical forms are equal.        *)
let set_eq s1 s2 = setify s1 = setify s2;;

(* Add x to the set s, as the union of the singleton [x] with s.             *)
let insert x s = union (x :: []) s;;

(* Image of s under f, returned in canonical set form.                       *)
let smap f s =
  let image = map f s in
  setify image;;

(* ------------------------------------------------------------------------- *)
(* Union of a family of sets.                                                *)
(* ------------------------------------------------------------------------- *)

(* Union of a list of lists: concatenate them all, then canonicalize.        *)
let unions s = setify (itlist (fun piece acc -> piece @ acc) s []);;

(* ------------------------------------------------------------------------- *)
(* List membership. This does *not* assume the list is a set.                *)
(* ------------------------------------------------------------------------- *)

(* [mem x lis] is true iff some element of lis equals x. The list is         *)
(* scanned linearly; no ordering is assumed.                                 *)
let rec mem x lis =
  match lis with
  | [] -> false
  | h :: t -> x = h || mem x t;;

(* ------------------------------------------------------------------------- *)
(* Finding all subsets or all subsets of a given size.                       *)
(* ------------------------------------------------------------------------- *)

(* All sublists of l with exactly m elements, preserving element order.      *)
let rec allsets m l =
  match m, l with
  | 0, _ -> [[]]
  | _, [] -> []
  | _, h :: t ->
      let with_h = map (fun g -> h :: g) (allsets (m - 1) t) in
      with_h @ allsets m t;;

(* All sublists of s: those containing the head come first, followed by      *)
(* those that omit it.                                                       *)
let rec allsubsets = function
  | [] -> [[]]
  | a :: t ->
      let without = allsubsets t in
      map (fun b -> a :: b) without @ without;;

(* All subsets of s except the empty one (result is in canonical set form    *)
(* because it goes through subtract).                                        *)
let allnonemptysubsets s =
  let all = allsubsets s in
  subtract all [[]];;

(* ------------------------------------------------------------------------- *)
(* Explosion and implosion of strings.                                       *)
(* ------------------------------------------------------------------------- *)

(* Break a string into a list of one-character strings.                      *)
let explode s =
  (* Walk from the last index down so the list is built in order. *)
  let rec collect i acc =
    if i >= 0 then collect (i - 1) (String.make 1 s.[i] :: acc) else acc in
  collect (String.length s - 1) [];;

(* Concatenate a list of strings into one string (inverse of explode).       *)
(* String.concat builds the result in a single pass instead of repeatedly    *)
(* copying with (^).                                                         *)
let implode l = String.concat "" l;;

(* ------------------------------------------------------------------------- *)
(* Timing; useful for documentation but not logically necessary.             *)
(* ------------------------------------------------------------------------- *)

(* Evaluate [f x], print the difference in Sys.time() across the call, and   *)
(* return the result of [f x].                                               *)
let time f x =
  let t0 = Sys.time () in
  let result = f x in
  let elapsed = Sys.time () -. t0 in
  print_string ("CPU time (user): " ^ string_of_float elapsed);
  print_newline ();
  result;;

(* ------------------------------------------------------------------------- *)
(* Representation of finite partial functions as balanced trees.             *)
(* Alas, there's no polymorphic one available in the standard library.       *)
(* So this is basically a copy of what's there.                              *)
(* ------------------------------------------------------------------------- *)

(* Finite partial function from 'a to 'b, stored as a binary search tree.    *)
(* Node(l, k, v, r, h): left subtree, key, value bound to the key, right     *)
(* subtree, and the height of this node as cached by the tree operations.    *)
type ('a,'b)func =
    Empty
  | Node of ('a,'b)func * 'a * 'b * ('a,'b)func * int;;

(* Operations on finite partial functions, built over a height-balanced      *)
(* binary search tree ordered by the polymorphic comparisons:                *)
(*   apply f x      value bound to x in f; fails with "apply" if unbound     *)
(*   undefined      the function with no bindings                            *)
(*   (x |-> y) f    f with x bound to y (replacing any earlier binding)      *)
(*   undefine x f   f with any binding for x removed                         *)
(*   dom f          the bound keys, as a canonical set                       *)
(*   funset f       the (key,value) pairs, as a canonical set                *)
let apply,undefined,(|->),undefine,dom,funset =
  (* Three-way comparison derived from polymorphic (=) and (<). *)
  let compare x y = if x = y then 0 else if x < y then -1 else 1 in
  let height = function
      Empty -> 0
    | Node(_,_,_,_,h) -> h in
  (* Build a node, computing its height from the two subtrees. *)
  let create l x d r =
    let hl = height l and hr = height r in
    Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1)) in
  (* Build a node, rotating first if the subtree heights differ by more *)
  (* than 2. *)
  let bal l x d r =
    let hl = match l with Empty -> 0 | Node(_,_,_,_,h) -> h in
    let hr = match r with Empty -> 0 | Node(_,_,_,_,h) -> h in
    if hl > hr + 2 then begin
      match l with
        Empty -> invalid_arg "Map.bal"
      | Node(ll, lv, ld, lr, _) ->
          if height ll >= height lr then
            create ll lv ld (create lr x d r)
          else begin
            match lr with
              Empty -> invalid_arg "Map.bal"
            | Node(lrl, lrv, lrd, lrr, _)->
                create (create ll lv ld lrl) lrv lrd (create lrr x d r)
          end
    end else if hr > hl + 2 then begin
      match r with
        Empty -> invalid_arg "Map.bal"
      | Node(rl, rv, rd, rr, _) ->
          if height rr >= height rl then
            create (create l x d rl) rv rd rr
          else begin
            match rl with
              Empty -> invalid_arg "Map.bal"
            | Node(rll, rlv, rld, rlr, _) ->
                create (create l x d rll) rlv rld (create rlr rv rd rr)
          end
    end else
      Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1)) in
  (* Insert or overwrite the binding for x. *)
  let rec add x data = function
      Empty ->
        Node(Empty, x, data, Empty, 1)
    | Node(l, v, d, r, h) ->
        let c = compare x v in
        if c = 0 then
          Node(l, x, data, r, h)
        else if c < 0 then
          bal (add x data l) v d r
        else
          bal l v d (add x data r) in
  (* Look up x; raises Not_found when there is no binding. *)
  let rec find x = function
      Empty ->
        raise Not_found
    | Node(l, v, d, r, _) ->
        let c = compare x v in
        if c = 0 then d
        else find x (if c < 0 then l else r) in
  (* Join two trees; used by remove on the two subtrees of the deleted *)
  (* node, so every key of t1 is below every key of t2. *)
  let rec merge t1 t2 =
    match (t1, t2) with
      (Empty, t) -> t
    | (t, Empty) -> t
    | (Node(l1, v1, d1, r1, _), Node(l2, v2, d2, r2, _)) ->
        bal l1 v1 d1 (bal (merge r1 l2) v2 d2 r2) in
  (* Delete the binding for x, if any. *)
  let rec remove x = function
      Empty ->
        Empty
    | Node(l, v, d, r, _) ->
        let c = compare x v in
        if c = 0 then
          merge l r
        else if c < 0 then
          bal (remove x l) v d r
        else
          bal l v d (remove x r) in
  (* Fold over the bindings, visiting the right subtree first. *)
  let rec fold f m accu =
    match m with
      Empty -> accu
    | Node(l, v, d, r, _) ->
        fold f l (f v d (fold f r accu)) in
  (* Callers see Failure, not Not_found, for an unbound key. *)
  let apply f x = try find x f with Not_found -> failwith "apply" in
  let undefined = Empty in
  let valmod x y f = add x y f in
  let undefine a f = remove a f in
  let dom f = setify(fold (fun x y a -> x::a) f []) in
  let funset f = setify(fold (fun x y a -> (x,y)::a) f []) in
apply,undefined,valmod,undefine,dom,funset;;

(* tryapplyd f a d: value of a under f, or the default d if apply fails.     *)
let tryapplyd f a d =
  try apply f a
  with Failure _ -> d;;
(* tryapply: as tryapplyd, with the looked-up key itself as the default.     *)
let tryapply f x = tryapplyd f x x;;
(* tryapplyl: as tryapplyd, with the empty list as the default.              *)
let tryapplyl f x = tryapplyd f x [];;
(* x := y is the one-point function binding x to y. Note that this rebinds   *)
(* the operator name (:=).                                                   *)
let (:=) x y = (x |-> y) undefined;;
(* fpf applies a list of function updates, last first, to the empty one.     *)
let fpf assigs = itlist identity assigs undefined;;
(* defined f x: does apply f x complete without Failure?                     *)
let defined f x = can (fun key -> apply f key) x;;

(* ------------------------------------------------------------------------- *)
(* Install a (trivial) printer for finite partial functions.                 *)
(* ------------------------------------------------------------------------- *)

(* Placeholder printer: prints the fixed text "<func>" for any function.     *)
let print_fpf (_ : ('a,'b) func) = print_string "<func>";;

(*
#install_printer print_fpf;;
*)

(* ------------------------------------------------------------------------- *)
(* Related stuff for standard functions.                                     *)
(* ------------------------------------------------------------------------- *)

(* [valmod a y f] behaves like f except that it maps a to y.                 *)
let valmod a y f = fun x -> if x = a then y else f x;;

(* The everywhere-undefined function: fails on every argument.               *)
let undef _ = failwith "undefined function";;

(* ------------------------------------------------------------------------- *)
(* Union-find algorithm.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Node of the union-find structure. Nonterminal b redirects to another      *)
(* element b; Terminal(p,n) ends a chain at p. NOTE(review): n looks like a  *)
(* size count used by equate to pick which root survives -- confirm.         *)
type ('a)pnode = Nonterminal of 'a | Terminal of 'a * int;;

(* A partition is a finite partial function from elements to their nodes.    *)
type ('a)partition = Partition of ('a,('a)pnode)func;;

(* Follow Nonterminal links from a until a Terminal node is reached, and     *)
(* return its two components. Fails (via apply) if an element on the way has *)
(* no entry.                                                                 *)
let rec terminus ptn a =
  let Partition f = ptn in
  match apply f a with
  | Terminal (p, q) -> (p, q)
  | Nonterminal next -> terminus ptn next;;

(* As terminus, but yields (a,0) when the lookup fails.                      *)
let tryterminus ptn a =
  try terminus ptn a
  with Failure _ -> (a, 0);;

(* Representative reached from a, or a itself when terminus fails.           *)
let canonize ptn a =
  try
    let (root, _) = terminus ptn a in
    root
  with Failure _ -> a;;

(* Merge the classes of a and b. If both already reach the same terminal the *)
(* partition is unchanged; otherwise the terminal with the smaller count is  *)
(* redirected to the other (to b's on a tie), whose count becomes the sum.   *)
let equate (a,b) (Partition f as ptn) =
  let (ra, na) = tryterminus ptn a in
  let (rb, nb) = tryterminus ptn b in
  (* Record root as terminal with the combined count, then point child at it. *)
  let link child root g =
    (child |-> Nonterminal root) ((root |-> Terminal (root, na + nb)) g) in
  Partition
    (if ra = rb then f
     else if na <= nb then link ra rb f
     else link rb ra f);;

(* The partition with no entries at all.                                     *)
let unequal = Partition undefined;;

(* All elements that have an entry in the partition, as a canonical set.     *)
let equated ptn =
  let Partition f = ptn in
  dom f;;

(* ------------------------------------------------------------------------- *)
(* First number starting at n for which p succeeds.                          *)
(* ------------------------------------------------------------------------- *)

(* Smallest num n, n+1, n+2, ... at which p holds. Does not terminate if p   *)
(* never holds.                                                              *)
let rec first n p =
  match p n with
  | true -> n
  | false -> first (n +/ Int 1) p;;
(* ========================================================================= *)
(* Simple algebraic expression example from the introductory chapter.        *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Abstract syntax of simple algebraic expressions: named variables, integer *)
(* constants, and binary addition and multiplication.                        *)
type expression =
   Var of string
 | Const of int
 | Add of expression * expression
 | Mul of expression * expression;;

(* ------------------------------------------------------------------------- *)
(* Trivial example of using the type constructors.                           *)
(* ------------------------------------------------------------------------- *)

(*
Add(Mul(Const 2,Var "x"),Var "y");;
*)

(* ------------------------------------------------------------------------- *)
(* Simplification example.                                                   *)
(* ------------------------------------------------------------------------- *)

(* One simplification step at the root of an expression: fold constant       *)
(* operands and apply the identities for 0 + x, x + 0, 0 * x, x * 0, 1 * x   *)
(* and x * 1. Anything else is returned unchanged.                           *)
let simplify1 expr =
  match expr with
  | Add (Const m, Const n) -> Const (m + n)
  | Mul (Const m, Const n) -> Const (m * n)
  | Add (Const 0, e) | Add (e, Const 0) -> e
  | Mul (Const 0, _) | Mul (_, Const 0) -> Const 0
  | Mul (Const 1, e) | Mul (e, Const 1) -> e
  | _ -> expr;;

(* Bottom-up simplification: simplify both operands, then apply simplify1    *)
(* to the rebuilt node. Variables and constants are returned as they are.    *)
let rec simplify = function
  | Add (l, r) -> simplify1 (Add (simplify l, simplify r))
  | Mul (l, r) -> simplify1 (Mul (simplify l, simplify r))
  | (Var _ | Const _) as leaf -> leaf;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)
(*
let e = Add(Mul(Add(Mul(Const(0),Var "x"),Const(1)),Const(3)),
            Const(12));;
simplify e;;
*)

(* ------------------------------------------------------------------------- *)
(* Lexical analysis.                                                         *)
(* ------------------------------------------------------------------------- *)

(* [matches s] is a predicate on one-character strings: true for those that  *)
(* occur in s. The string is exploded once, when s is supplied.              *)
let matches s =
  let alphabet = explode s in
  fun ch -> mem ch alphabet;;

(* Character classes used by the lexer, each a predicate built by matches.   *)
let space = matches " \t\n";;
let punctuation = matches "()[]{},";;
let symbolic = matches "~`!@#$%^&*-+=|\\:;<>.?/";;
let numeric = matches "0123456789";;
let alphanumeric =
  matches "abcdefghijklmnopqrstuvwxyz_'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";;

(* Consume the longest prefix of inp whose items all satisfy prop. Returns   *)
(* that prefix concatenated into one string, and the unconsumed remainder.   *)
let rec lexwhile prop inp =
  match inp with
  | c :: cs when prop c ->
      let tok, rest = lexwhile prop cs in
      (c ^ tok, rest)
  | _ -> ("", inp);;

(* Split a list of characters into tokens. Leading white space is skipped;   *)
(* the first remaining character decides the token class: alphanumeric and   *)
(* symbolic tokens extend as far as the class does, punctuation is always a  *)
(* single character, and any other character is an error.                    *)
let rec lex inp =
  match snd (lexwhile space inp) with
  | [] -> []
  | c :: cs ->
      let prop =
        if alphanumeric c then alphanumeric
        else if symbolic c then symbolic
        else if punctuation c then (fun _ -> false)
        else failwith "Unknown character in input" in
      let toktl, rest = lexwhile prop cs in
      (c ^ toktl) :: lex rest;;

(*
lex(explode "2*((var_1 + x') + 11)");;
lex(explode "if (*p1-- == *p2++) then f() else g()");;
*)

(* ------------------------------------------------------------------------- *)
(* Parsing.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* Recursive-descent parser for expressions over a token list. Each function *)
(* returns the parsed expression together with the unconsumed tokens.        *)
(*   expression ::= product [ "+" expression ]                               *)
(*   product    ::= atom [ "*" product ]                                     *)
(*   atom       ::= "(" expression ")" | numeral | identifier                *)
let rec parse_expression inp =
  let e1, inp1 = parse_product inp in
  match inp1 with
  | "+" :: rest ->
      let e2, inp2 = parse_expression rest in
      (Add (e1, e2), inp2)
  | _ -> (e1, inp1)

and parse_product inp =
  let e1, inp1 = parse_atom inp in
  match inp1 with
  | "*" :: rest ->
      let e2, inp2 = parse_product rest in
      (Mul (e1, e2), inp2)
  | _ -> (e1, inp1)

and parse_atom inp =
  match inp with
  | [] -> failwith "Expected an expression at end of input"
  | "(" :: toks ->
      (match parse_expression toks with
       | e, ")" :: rest -> (e, rest)
       | _ -> failwith "Expected closing bracket")
  | tok :: toks ->
      (* A token made only of digits is a constant, anything else a variable. *)
      if forall numeric (explode tok) then (Const (int_of_string tok), toks)
      else (Var tok, toks);;

(* ------------------------------------------------------------------------- *)
(* Generic function to impose lexing and exhaustion checking on a parser.    *)
(* ------------------------------------------------------------------------- *)

(* Turn a token-list parser into a string parser: lex the string, run pfn,   *)
(* and fail with "Unparsed input" unless every token was consumed.           *)
let make_parser pfn s =
  match pfn (lex (explode s)) with
  | expr, [] -> expr
  | _, _ :: _ -> failwith "Unparsed input";;

(* ------------------------------------------------------------------------- *)
(* Our parser.                                                               *)
(* ------------------------------------------------------------------------- *)

(* Parse a string as an expression.                                          *)
let parsee s = make_parser parse_expression s;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
parsee "x + 1";;

parsee "(x1 + x2 + x3) * (1 + 2 + 3 * x + y)";;

*)

(* ------------------------------------------------------------------------- *)
(* Conservatively bracketing first attempt at printer.                       *)
(* ------------------------------------------------------------------------- *)

(* Convert an expression to a string, bracketing every sum and product.      *)
let rec string_of_exp e =
  let infix op l r = "(" ^ string_of_exp l ^ op ^ string_of_exp r ^ ")" in
  match e with
  | Var name -> name
  | Const k -> string_of_int k
  | Add (l, r) -> infix " + " l r
  | Mul (l, r) -> infix " * " l r;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
string_of_exp(parsee "x + 3 * y");;
let e = parsee "3 * (y + z) + 7 * 4";;
parsee(string_of_exp e) = e;;
*)

(* ------------------------------------------------------------------------- *)
(* Somewhat better attempt.                                                  *)
(* ------------------------------------------------------------------------- *)

(* Convert an expression to a string with precedence-aware bracketing: pr is *)
(* the precedence of the context, and a sum (level 2) or product (level 4)   *)
(* is bracketed only when pr exceeds its level. The left operand is printed  *)
(* one level higher than the right.                                          *)
let rec string_of_exp pr e =
  let wrap level s = if pr > level then "(" ^ s ^ ")" else s in
  match e with
  | Var name -> name
  | Const k -> string_of_int k
  | Add (l, r) -> wrap 2 (string_of_exp 3 l ^ " + " ^ string_of_exp 2 r)
  | Mul (l, r) -> wrap 4 (string_of_exp 5 l ^ " * " ^ string_of_exp 4 r);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
string_of_exp 0 (parsee "x + 3 * y");;
string_of_exp 0 (parsee "(x + 3) * y");;
string_of_exp 0 (parsee "1 + 2 + 3");;
string_of_exp 0 (parsee "((1 + 2) + 3) + 4");;
*)

(* ------------------------------------------------------------------------- *)
(* Example shows the problem.                                                *)
(* ------------------------------------------------------------------------- *)

(*
let e = parsee "(x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10) *
                (y1 + y2 + y3 + y4 + y5 + y6 + y7 + y8 + y9 + y10)";;

string_of_exp 0 e;;
*)

(* ------------------------------------------------------------------------- *)
(* Real printer with proper line breaks.                                     *)
(* ------------------------------------------------------------------------- *)

(* Print an expression through Format with break hints after each operator.  *)
(* pr is the precedence of the context; sums (level 2) and products          *)
(* (level 4) are bracketed, inside a box of their own, when pr exceeds       *)
(* their level.                                                              *)
let rec print_exp pr e =
  let infix level sym l r =
    let bracketed = pr > level in
    if bracketed then (print_string "("; open_box 0);
    print_exp (level + 1) l;
    print_string sym; print_space ();
    print_exp level r;
    if bracketed then (close_box (); print_string ")") in
  match e with
  | Var s -> print_string s
  | Const n -> print_int n
  | Add (l, r) -> infix 2 " +" l r
  | Mul (l, r) -> infix 4 " *" l r;;

(* Print an expression between "<<" and ">>", with the whole output and the  *)
(* expression itself each inside a Format box.                               *)
let print_expression e =
  let in_box body = open_box 0; body (); close_box () in
  in_box (fun () ->
    print_string "<<";
    in_box (fun () -> print_exp 0 e);
    print_string ">>");;

(* ------------------------------------------------------------------------- *)
(* Also set up parsing of quotations.                                        *)
(* ------------------------------------------------------------------------- *)

(* The parser used for quotations: the expression parser.                    *)
let default_parser s = parsee s;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
print_expression(Mul(Const 3,Add(Mul(e,e),Mul(e,e))));;

#install_printer print_expression;;

parsee "3 + x * y";;

<<3 + x * y>>;;

*)
(* ========================================================================= *)
(* Polymorphic type of formulas with parser and printer.                     *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Formulas over atoms of type 'a: the two truth constants, atoms, negation, *)
(* the four binary connectives, and quantifiers binding a variable given by  *)
(* name.                                                                     *)
type ('a)formula = False
                 | True
                 | Atom of 'a
                 | Not of ('a)formula
                 | And of ('a)formula * ('a)formula
                 | Or of ('a)formula * ('a)formula
                 | Imp of ('a)formula * ('a)formula
                 | Iff of ('a)formula * ('a)formula
                 | Forall of string * ('a)formula
                 | Exists of string * ('a)formula;;

(* ------------------------------------------------------------------------- *)
(* General homomorphism and iteration functions for atoms in formula.        *)
(* ------------------------------------------------------------------------- *)

(* Replace every atom a in fm by the formula [fn a], leaving the rest of the *)
(* structure as it is.                                                       *)
let rec onatoms fn fm =
  let go = onatoms fn in
  match fm with
  | False | True -> fm
  | Atom a -> fn a
  | Not p -> Not (go p)
  | And (p, q) -> And (go p, go q)
  | Or (p, q) -> Or (go p, go q)
  | Imp (p, q) -> Imp (go p, go q)
  | Iff (p, q) -> Iff (go p, go q)
  | Forall (x, p) -> Forall (x, go p)
  | Exists (x, p) -> Exists (x, go p);;

(* Fold f over the atoms of fm, starting from b. For a binary connective the *)
(* right operand is folded first and its result passed on to the left.       *)
let rec overatoms f fm b =
  match fm with
  | False | True -> b
  | Atom a -> f a b
  | Not p | Forall (_, p) | Exists (_, p) -> overatoms f p b
  | And (p, q) | Or (p, q) | Imp (p, q) | Iff (p, q) ->
      overatoms f p (overatoms f q b);;

(* ------------------------------------------------------------------------- *)
(* Special case of a union of the results of a function over the atoms.      *)
(* ------------------------------------------------------------------------- *)

(* Collect the lists [f a] for every atom a of fm into one canonical set.    *)
let atom_union f fm =
  let collect a acc = f a @ acc in
  setify (overatoms collect fm []);;

(* ------------------------------------------------------------------------- *)
(* General parsing of iterated infixes.                                      *)
(* ------------------------------------------------------------------------- *)

(* Generic parser for a sequence of items separated by the token opsym.      *)
(*   subparser  parses one item, returning it with the remaining tokens      *)
(*   sof        maps the final item to the overall result                    *)
(*   opupdate   given the current sof and an item just parsed, produces the  *)
(*              sof to use for the rest of the sequence                      *)
(* Returns the result together with the unconsumed tokens.                   *)
let rec parse_ginfix opsym opupdate sof subparser inp =
  let e1, inp1 = subparser inp in
  match inp1 with
  | tok :: rest when tok = opsym ->
      parse_ginfix opsym opupdate (opupdate sof e1) subparser rest
  | _ -> (sof e1, inp1);;

(* Parse items separated by opsym, combining them with opcon to the left:    *)
(* a op b op c becomes opcon(opcon(a,b),c).                                  *)
let parse_left_infix opsym opcon =
  let extend f e1 e2 = opcon (f e1, e2) in
  parse_ginfix opsym extend (fun x -> x);;

(* Parse items separated by opsym, combining them with opcon to the right:   *)
(* a op b op c becomes opcon(a,opcon(b,c)).                                  *)
let parse_right_infix opsym opcon =
  let extend f e1 e2 = f (opcon (e1, e2)) in
  parse_ginfix opsym extend (fun x -> x);;

(* Parse items separated by opsym into a list, in order of appearance.       *)
let parse_list opsym =
  let extend f e1 e2 = f e1 @ [e2] in
  parse_ginfix opsym extend (fun x -> [x]);;

(* ------------------------------------------------------------------------- *)
(* Other general parsing combinators.                                        *)
(* ------------------------------------------------------------------------- *)

(* Apply f to the result part of a (result, remaining input) pair.           *)
let papply f parsed =
  let (ast, rest) = parsed in
  (f ast, rest);;

(* [nextin inp tok] is true iff inp is non-empty and its first item equals   *)
(* tok.                                                                      *)
let nextin inp tok =
  match inp with
  | h :: _ -> h = tok
  | [] -> false;;

(* Run subparser, then require the closing token cbra and consume it.        *)
(* Fails with "Closing bracket expected" if cbra does not come next.         *)
let parse_bracketed subparser cbra inp =
  let ast, rest = subparser inp in
  if not (nextin rest cbra) then failwith "Closing bracket expected"
  else (ast, tl rest);;

(* ------------------------------------------------------------------------- *)
(* Parsing of formulas, parametrized by atom parser "pfn".                   *)
(* ------------------------------------------------------------------------- *)

(* Formula parser over a token list, parametrized by the atom parser pfn.    *)
(* vs is the list of quantified variable names currently in scope; it is     *)
(* passed on to pfn. Each function returns the parsed formula with the       *)
(* unconsumed tokens.                                                        *)
let rec parse_atomic_formula pfn vs inp =
  match inp with
  | [] -> failwith "formula expected"
  | "false" :: rest -> (False, rest)
  | "true" :: rest -> (True, rest)
  | "(" :: rest ->
      (* Try the atom parser on the bracket first; if it fails, treat the *)
      (* bracket as enclosing a whole formula. *)
      (try pfn vs inp
       with Failure _ -> parse_bracketed (parse_formula pfn vs) ")" rest)
  | "~" :: rest ->
      let p, rest' = parse_atomic_formula pfn vs rest in
      (Not p, rest')
  | "forall" :: x :: rest ->
      parse_quant pfn (x :: vs) (fun (v, body) -> Forall (v, body)) x rest
  | "exists" :: x :: rest ->
      parse_quant pfn (x :: vs) (fun (v, body) -> Exists (v, body)) x rest
  | _ -> pfn vs inp

(* Parse the rest of a quantifier after the variable x: either "." followed  *)
(* by the body, or another variable bound by the same quantifier.            *)
and parse_quant pfn vs qcon x inp =
  match inp with
  | [] -> failwith "Body of quantified term expected"
  | y :: rest ->
      let body, rest' =
        if y = "." then parse_formula pfn vs rest
        else parse_quant pfn (y :: vs) qcon y rest in
      (qcon (x, body), rest')

(* Binary connectives, all right-associative, from tightest to loosest:      *)
(* /\, \/, ==>, <=>.                                                         *)
and parse_formula pfn vs inp =
  let atomic = parse_atomic_formula pfn vs in
  let conj = parse_right_infix "/\\" (fun (p, q) -> And (p, q)) atomic in
  let disj = parse_right_infix "\\/" (fun (p, q) -> Or (p, q)) conj in
  let impl = parse_right_infix "==>" (fun (p, q) -> Imp (p, q)) disj in
  parse_right_infix "<=>" (fun (p, q) -> Iff (p, q)) impl inp;;

(* ------------------------------------------------------------------------- *)
(* Printing of formulas, parametrized by atom printer.                       *)
(* ------------------------------------------------------------------------- *)

(* Peel off the leading run of quantifiers of one kind (universal when       *)
(* isforall is true, existential otherwise). Returns the bound variables,    *)
(* outermost first, paired with the remaining body.                          *)
let rec strip_quant isforall fm =
  match fm with
  | Forall(x,p) when isforall ->
        let xs,body = strip_quant isforall p in
        x::xs,body
  | Exists(x,p) when not isforall ->
        let xs,body = strip_quant isforall p in
        x::xs,body
  | _ -> [],fm;;

(* Print fm using pfn for atoms. prec is the precedence of the surrounding   *)
(* context; it decides whether an infix or quantified formula is bracketed.  *)
let rec print_formula pfn prec fm =
  let infix newprec sym p q = print_infix_formula pfn prec newprec sym p q
  and quant qname isforall =
    print_quant pfn prec qname (strip_quant isforall fm) in
  match fm with
  | True -> print_string "true"
  | False -> print_string "false"
  | Atom(pargs) -> pfn prec pargs
  | Not(p) -> print_string "~"; print_formula pfn 10 p
  | And(p,q) -> infix 8 "/\\" p q
  | Or(p,q) -> infix 6 "\\/" p q
  | Imp(p,q) -> infix 4 "==>" p q
  | Iff(p,q) -> infix 2 "<=>" p q
  | Forall(_,_) -> quant "forall" true
  | Exists(_,_) -> quant "exists" false

(* Print a quantifier name, its bound variables and the body; the whole is   *)
(* bracketed unless the context precedence is 0.                             *)
and print_quant pfn prec qname (bvs,bod) =
  let bracketed = prec <> 0 in
  if bracketed then print_string "(";
  print_string qname;
  do_list (fun v -> print_string " "; print_string v) bvs;
  print_string ". ";
  open_box 0;
  print_formula pfn 0 bod;
  close_box();
  if bracketed then print_string ")"

(* Print p sym q. Brackets and a box are added when the context precedence   *)
(* oldprec exceeds newprec; the left operand is printed at newprec+1 and     *)
(* the right one at newprec.                                                 *)
and print_infix_formula pfn oldprec newprec sym p q =
  let bracketed = oldprec > newprec in
  if bracketed then begin print_string "("; open_box 0 end;
  print_formula pfn (newprec+1) p;
  print_string(" "^sym);
  print_space();
  print_formula pfn newprec q;
  if bracketed then begin close_box(); print_string ")" end;;

(* Print fm between the delimiters << and >>, with one formatting box        *)
(* around the whole output and another around the formula itself.            *)
let formula_printer pfn fm =
  let in_box body = (open_box 0; body(); close_box()) in
  in_box (fun () ->
    print_string "<<";
    in_box (fun () -> print_formula pfn 0 fm);
    print_string ">>");;
(* ========================================================================= *)
(* Basic stuff for propositional logic: datatype, parsing and printing.      *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* A propositional variable: the constructor P wraps the variable name.      *)
type prop = P of string;;

(* Return the name carried by a propositional variable.                      *)
let pname p =
  match p with P s -> s;;

(* ------------------------------------------------------------------------- *)
(* Parsing of propositional formulas.                                        *)
(* ------------------------------------------------------------------------- *)

(* Atom parser for propositional logic: any first token other than an        *)
(* opening bracket becomes a propositional variable. Fails on empty input    *)
(* or on an opening bracket. The bound variable list vs is not used.         *)
let parse_propvar vs inp =
  match inp with
  | [] | "("::_ -> failwith "parse_propvar"
  | p::oinp -> Atom(P p),oinp;;

(* Parser for propositional formulas: the generic formula parser with        *)
(* parse_propvar as atom parser and no bound variables, wrapped by           *)
(* make_parser (defined elsewhere; presumably it lexes a string and checks   *)
(* that all input is consumed -- confirm).                                   *)
let parsep =
  let propositional_formula = parse_formula parse_propvar [] in
  make_parser propositional_formula;;

(* ------------------------------------------------------------------------- *)
(* Set this up as default for quotations.                                    *)
(* ------------------------------------------------------------------------- *)

(* Alias for parsep. NOTE(review): presumably the name looked up by the      *)
(* quotation mechanism for <<...>> -- confirm against the quotation setup.   *)
let default_parser = parsep;;

(* ------------------------------------------------------------------------- *)
(* Test of the parser.                                                       *)
(* ------------------------------------------------------------------------- *)

(*
let fm = parsep "p ==> q <=> r /\ s \/ (t <=> ~ ~u /\ v)";;

let fm = <<p ==> q <=> r /\ s \/ (t <=> ~ ~u /\ v)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Printer.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* Atom printer for propositional variables: prints the variable name.       *)
(* The precedence argument is ignored.                                       *)
let print_propvar prec p =
  match p with P name -> print_string name;;

(* Printer for propositional formulas, suitable for installing as a          *)
(* toplevel printer.                                                         *)
let pr fm = formula_printer print_propvar fm;;

(*
#install_printer pr;;
*)

(* ------------------------------------------------------------------------- *)
(* Testing the printer.                                                      *)
(* ------------------------------------------------------------------------- *)

(*
And(fm,fm);;

And(Or(fm,fm),fm);;
*)

(* ------------------------------------------------------------------------- *)
(* Interpretation of formulas.                                               *)
(* ------------------------------------------------------------------------- *)

(* Truth value of the propositional formula fm under the valuation v, a      *)
(* function from atoms to booleans. Quantified formulas are not handled and  *)
(* fall through the match. Uses the standard short-circuit operators in      *)
(* place of the deprecated ampersand and or forms.                           *)
let rec eval fm v =
  match fm with
    False -> false
  | True -> true
  | Atom(x) -> v(x)
  | Not(p) -> not(eval p v)
  | And(p,q) -> (eval p v) && (eval q v)
  | Or(p,q) -> (eval p v) || (eval q v)
  | Imp(p,q) -> not(eval p v) || (eval q v)
  | Iff(p,q) -> (eval p v) = (eval q v);;

(* ------------------------------------------------------------------------- *)
(* Example of how we could define connective interpretations ourselves.      *)
(* ------------------------------------------------------------------------- *)

(* Boolean implication: false exactly when p is true and q is false.         *)
let (-->) p q = if p then q else true;;

(* Same evaluator as before, but with implication interpreted through the    *)
(* user-defined arrow operator. Note that both operands of an implication    *)
(* are evaluated here, since the arrow is an ordinary function. Quantified   *)
(* formulas are not handled and fall through the match. Uses the standard    *)
(* short-circuit operators in place of the deprecated forms.                 *)
let rec eval fm v =
  match fm with
    False -> false
  | True -> true
  | Atom(x) -> v(x)
  | Not(p) -> not(eval p v)
  | And(p,q) -> (eval p v) && (eval q v)
  | Or(p,q) -> (eval p v) || (eval q v)
  | Imp(p,q) -> eval p v --> eval q v
  | Iff(p,q) -> (eval p v) = (eval q v);;

(* ------------------------------------------------------------------------- *)
(* Example of use, showing the "partial" evaluation.                         *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<p /\ q ==> q /\ r>>;;

let fm_interp = eval fm;;

eval fm (function P"p" -> true | P"q" -> false | P"r" -> true);;

eval fm (function P"p" -> true | P"q" -> true | P"r" -> false);;
*)

(* ------------------------------------------------------------------------- *)
(* Return the set of propositional variables in a formula.                   *)
(* ------------------------------------------------------------------------- *)

(* Collect the atoms of fm by mapping each atom to a singleton list and      *)
(* combining the results with atom_union (defined elsewhere).                *)
let atoms fm =
  let singleton a = [a] in
  atom_union singleton fm;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
atoms <<p /\ q \/ s ==> ~p \/ (r <=> s)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Code to print out truth tables.                                           *)
(* ------------------------------------------------------------------------- *)

(* Test whether subfn holds for every valuation obtained from v by assigning *)
(* false or true to each atom in pvs. Evaluation stops at the first          *)
(* valuation for which subfn returns false. Uses the standard short-circuit  *)
(* conjunction in place of the deprecated ampersand operator.                *)
let rec onallvaluations subfn v pvs =
  match pvs with
    [] -> subfn v
  | p::ps -> (* v' t is v modified to map the atom p to t *)
             let v' t q = if q = p then t else v(q) in
             onallvaluations subfn (v' false) ps &&
             onallvaluations subfn (v' true) ps;;

(* Print the truth table of fm: a header row naming each atom, a separator,  *)
(* one row per valuation of the atoms, and a closing separator.              *)
let print_truthtable fm =
  let pvs = atoms fm in
  (* Column width: one more than the larger of 5 and the longest atom name.  *)
  let width = itlist (max ** String.length ** pname) pvs 5 + 1 in
  let fixw s = s^String.make(width - String.length s) ' ' in
  let truthstring p = fixw (if p then "true" else "false") in
  (* mk_row always returns true so that onallvaluations visits every row.    *)
  let mk_row v =
     let lis = map (fun x -> truthstring(v x)) pvs
     and ans = truthstring(eval fm v) in
     print_string(itlist (^) lis ("| "^ans)); print_newline();
     true in
  let separator = String.make (width * length pvs + 9) '-' in
  print_string(itlist (fun s t -> fixw(pname s) ^ t) pvs "| formula");
  print_newline(); print_string separator; print_newline();
  (* The boolean result is of no interest here; discard it explicitly        *)
  (* instead of leaving a non-unit expression in the sequence.               *)
  ignore (onallvaluations mk_row (fun _ -> false) pvs);
  print_string separator; print_newline();;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<p /\ q ==> q /\ r>>;;

print_truthtable fm;;
*)

(* ------------------------------------------------------------------------- *)
(* Additional examples illustrating formula classes.                         *)
(* ------------------------------------------------------------------------- *)

(*
print_truthtable <<((p ==> q) ==> p) ==> p>>;;

print_truthtable <<p /\ ~p>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Recognizing tautologies.                                                  *)
(* ------------------------------------------------------------------------- *)

(* A formula is a tautology if it evaluates to true under every valuation    *)
(* of its atoms.                                                             *)
let tautology fm =
  let pvs = atoms fm in
  let all_false _ = false in
  onallvaluations (eval fm) all_false pvs;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
tautology <<p \/ ~p>>;;
tautology <<p \/ q ==> p>>;;
tautology <<p \/ q ==> q \/ (p <=> q)>>;;
tautology <<(p \/ q) /\ ~(p /\ q) ==> (~p <=> q)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Related concepts.                                                         *)
(* ------------------------------------------------------------------------- *)

(* fm is unsatisfiable exactly when its negation is a tautology.             *)
let unsatisfiable fm =
  let negation = Not fm in
  tautology negation;;

(* fm is satisfiable exactly when it is not unsatisfiable.                   *)
let satisfiable fm =
  if unsatisfiable fm then false else true;;

(* ------------------------------------------------------------------------- *)
(* Substitution operation.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Build a substitution function on formulas from subfn: each atom p is      *)
(* replaced by the result of tryapplyd on subfn and p, with the atom itself  *)
(* as the default (tryapplyd and onatoms are defined elsewhere).             *)
let propsubst subfn =
  let replace p = tryapplyd subfn p (Atom p) in
  onatoms replace;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let pandq = <<p /\ q>>;;

propsubst (P"p" := pandq) pandq;;
*)

(* ------------------------------------------------------------------------- *)
(* Surprising tautologies including Dijkstra's "Golden rule".                *)
(* ------------------------------------------------------------------------- *)

(*
tautology <<(p ==> q) \/ (q ==> p)>>;;

tautology <<p \/ (q <=> r) <=> (p \/ q <=> p \/ r)>>;;

tautology <<p /\ q <=> ((p <=> q) <=> p \/ q)>>;;

(* ------------------------------------------------------------------------- *)
(* Some logical equivalences allowing elimination of connectives.            *)
(* ------------------------------------------------------------------------- *)

tautology <<false <=> p /\ ~p>>;;
tautology <<true <=> ~(p /\ ~p)>>;;
tautology <<p \/ q <=> ~(~p /\ ~q)>>;;
tautology <<p ==> q <=> ~(p /\ ~q)>>;;
tautology <<(p <=> q) <=> ~(p /\ ~q) /\ ~(~p /\ q)>>;;

tautology <<true <=> false ==> false>>;;
tautology <<~p <=> p ==> false>>;;
tautology <<p /\ q <=> (p ==> q ==> false) ==> false>>;;
tautology <<p \/ q <=> (p ==> false) ==> q>>;;
tautology(parsep
  "(p <=> q) <=> ((p ==> q) ==> (q ==> p) ==> false) ==> false");;
*)

(* ------------------------------------------------------------------------- *)
(* Dualization.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Swap true with false and conjunction with disjunction throughout fm,      *)
(* leaving atoms and negation signs in place. Fails on any other kind of     *)
(* formula.                                                                  *)
let rec subdualize fm =
  match fm with
  | True -> False
  | False -> True
  | Atom _ -> fm
  | Not p -> Not(subdualize p)
  | Or(p,q) -> And(subdualize p,subdualize q)
  | And(p,q) -> Or(subdualize p,subdualize q)
  | _ -> failwith "Formula involves connectives ==> and <=>";;

(* The negation of the result of subdualize on fm.                           *)
let dualize fm =
  let swapped = subdualize fm in
  Not swapped;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
dualize <<p \/ ~p>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Routine simplification.                                                   *)
(* ------------------------------------------------------------------------- *)

(* One simplification step at the top of fm: remove a propositional          *)
(* constant that is negated or is an immediate operand of a binary           *)
(* connective. Any other formula is returned unchanged. The clauses are      *)
(* tried in order, which matters when both operands are constants.           *)
let psimplify1 fm =
  match fm with
  | Not False -> True
  | Not True -> False
  | And(False,_) | And(_,False) -> False
  | And(True,r) | And(r,True) -> r
  | Or(False,r) | Or(r,False) -> r
  | Or(True,_) | Or(_,True) -> True
  | Imp(False,_) -> True
  | Imp(True,q) -> q
  | Imp(_,True) -> True
  | Imp(p,False) -> Not p
  | Iff(True,r) | Iff(r,True) -> r
  | Iff(False,r) | Iff(r,False) -> Not r
  | _ -> fm;;

(* Simplify fm bottom-up: simplify the immediate subformulas first, then     *)
(* apply psimplify1 to the rebuilt formula. Formulas that are not built      *)
(* with a propositional connective are returned unchanged.                   *)
let rec psimplify fm =
  let binary mk p q = psimplify1 (mk (psimplify p) (psimplify q)) in
  match fm with
  | Not p -> psimplify1 (Not(psimplify p))
  | And(p,q) -> binary (fun l r -> And(l,r)) p q
  | Or(p,q) -> binary (fun l r -> Or(l,r)) p q
  | Imp(p,q) -> binary (fun l r -> Imp(l,r)) p q
  | Iff(p,q) -> binary (fun l r -> Iff(l,r)) p q
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
psimplify <<(true ==> (x <=> false)) ==> ~(y \/ false /\ z)>>;;

psimplify <<((x ==> y) ==> true) \/ ~false>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Negation normal form.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Negation normal form: eliminate implication and equivalence and push      *)
(* negations inwards through the binary connectives and double negation.     *)
(* Any other formula, including a negation of one, is returned as it is.     *)
let rec nnf fm =
  match fm with
  | And(p,q) -> And(nnf p,nnf q)
  | Or(p,q) -> Or(nnf p,nnf q)
  | Imp(p,q) -> Or(nnf(Not p),nnf q)
  | Iff(p,q) -> Or(And(nnf p,nnf q),And(nnf(Not p),nnf(Not q)))
  | Not negated ->
        (match negated with
         | Not p -> nnf p
         | And(p,q) -> Or(nnf(Not p),nnf(Not q))
         | Or(p,q) -> And(nnf(Not p),nnf(Not q))
         | Imp(p,q) -> And(nnf p,nnf(Not q))
         | Iff(p,q) -> Or(And(nnf p,nnf(Not q)),And(nnf(Not p),nnf q))
         | _ -> fm)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Side remark on possible alternative tautology.                            *)
(* ------------------------------------------------------------------------- *)

(*
tautology <<(p <=> q) <=> (p \/ ~q) /\ (~p \/ q)>>;;

(* ------------------------------------------------------------------------- *)
(* Example of NNF function in action.                                        *)
(* ------------------------------------------------------------------------- *)

let fm = <<(p <=> q) <=> ~(r ==> s)>>;;

let fm' = nnf fm;;

tautology(Iff(fm,fm'));;
*)

(* ------------------------------------------------------------------------- *)
(* More efficient version.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Compute in one traversal a pair of formulas: the negation normal form of  *)
(* fm and the negation normal form of its negation. In the local names, a    *)
(* trailing p marks the positive form and a trailing n the negated form.     *)
let rec nnfp fm =
  let pair p q = (nnfp p,nnfp q) in
  match fm with
  | Not(p) ->
        let pos,neg = nnfp p in
        neg,pos
  | And(p,q) ->
        let (pp,pn),(qp,qn) = pair p q in
        And(pp,qp),Or(pn,qn)
  | Or(p,q) ->
        let (pp,pn),(qp,qn) = pair p q in
        Or(pp,qp),And(pn,qn)
  | Imp(p,q) ->
        let (pp,pn),(qp,qn) = pair p q in
        Or(pn,qp),And(pp,qn)
  | Iff(p,q) ->
        let (pp,pn),(qp,qn) = pair p q in
        Or(And(pp,qp),And(pn,qn)),Or(And(pp,qn),And(pn,qp))
  | _ -> fm,Not fm;;

(* Negation normal form of fm: simplify first, then take the positive        *)
(* component of the pair computed by nnfp.                                   *)
let nnf fm =
  let pos_form,_ = nnfp (psimplify fm) in
  pos_form;;

(* ------------------------------------------------------------------------- *)
(* Some tautologies remarked on.                                             *)
(* ------------------------------------------------------------------------- *)

(*
tautology
 <<(p ==> p') /\ (q ==> q') ==> (p \/ q ==> p' \/ q')>>;;

tautology
 <<(p ==> p') /\ (q ==> q') ==> (p /\ q ==> p' /\ q')>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Tracking positive and negative occurrences.                               *)
(* ------------------------------------------------------------------------- *)

(* Return a pair of booleans telling whether the atom x occurs positively    *)
(* and whether it occurs negatively in fm. Negation swaps the two; the left  *)
(* side of an implication counts with reversed polarity; any occurrence      *)
(* inside an equivalence counts as both. Uses the standard disjunction       *)
(* operator in place of the deprecated or keyword.                           *)
let rec occurrences x fm =
  match fm with
    Atom(y) ->
        (x = y,false)
  | Not(p) ->
        let pos,neg = occurrences x p in neg,pos
  | And(p,q) | Or(p,q) ->
        let pos1,neg1 = occurrences x p
        and pos2,neg2 = occurrences x q in
        (pos1 || pos2,neg1 || neg2)
  | Imp(p,q) ->
        let pos1,neg1 = occurrences x p
        and pos2,neg2 = occurrences x q in
        (neg1 || pos2,pos1 || neg2)
  | Iff(p,q) ->
        let pos1,neg1 = occurrences x p
        and pos2,neg2 = occurrences x q in
        let either = pos1 || pos2 || neg1 || neg2 in
        (either,either)
  | _ -> (false,false);;

(* ------------------------------------------------------------------------- *)
(* Disjunctive normal form (DNF) via truth tables.                           *)
(* ------------------------------------------------------------------------- *)

(* Conjunction of the formulas in l, combined with end_itlist; the empty     *)
(* list gives True.                                                          *)
let list_conj l =
  match l with
  | [] -> True
  | _ -> end_itlist (fun p q -> And(p,q)) l;;

(* Disjunction of the formulas in l, combined with end_itlist; the empty     *)
(* list gives False.                                                         *)
let list_disj l =
  match l with
  | [] -> False
  | _ -> end_itlist (fun p q -> Or(p,q)) l;;

(* Conjunction describing the valuation v on the formulas pvs: each formula  *)
(* appears as itself if it evaluates to true under v, negated otherwise.     *)
let mk_lits pvs v =
  let literal p = if eval p v then p else Not p in
  list_conj (map literal pvs);;

(* List every valuation, obtained from v by assigning false or true to each  *)
(* atom in pvs, for which subfn returns true. Assignments with false come    *)
(* before those with true in the result.                                     *)
let rec allsatvaluations subfn v pvs =
  match pvs with
  | p::ps ->
        (* assign t is v modified to map the atom p to t *)
        let assign t = (fun q -> if q = p then t else v q) in
        allsatvaluations subfn (assign false) ps @
        allsatvaluations subfn (assign true) ps
  | [] -> if subfn v then [v] else [];;

(* Disjunctive normal form via truth tables: one disjunct per satisfying     *)
(* valuation of the atoms of fm, each disjunct being the conjunction of      *)
(* literals that describes that valuation.                                   *)
let dnf fm =
  let pvs = atoms fm in
  let satvals = allsatvaluations (eval fm) (fun _ -> false) pvs in
  let atom_fms = map (fun p -> Atom p) pvs in
  list_disj (map (mk_lits atom_fms) satvals);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<(p /\ (q \/ r /\ s)) /\ (~p \/ ~q \/ ~s)>>;;

dnf fm;;

print_truthtable fm;;

let fm = <<p /\ q /\ r /\ s /\ t /\ u \/ u /\ v>>;;

dnf fm;;
*)

(* ------------------------------------------------------------------------- *)
(* DNF via distribution.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Distribute a top-level conjunction over a disjunction in either operand,  *)
(* repeating on the resulting conjunctions. A disjunction on the right is    *)
(* handled before one on the left. Other formulas are returned unchanged.    *)
let rec distrib fm =
  match fm with
  | And(p,Or(q,r)) ->
        let left = distrib(And(p,q)) and right = distrib(And(p,r)) in
        Or(left,right)
  | And(Or(p,q),r) ->
        let left = distrib(And(p,r)) and right = distrib(And(q,r)) in
        Or(left,right)
  | _ -> fm;;

(* Convert to disjunctive form by recursing through conjunctions and         *)
(* disjunctions and applying distrib at each conjunction. Other formulas     *)
(* are treated as literals and returned unchanged.                           *)
let rec rawdnf fm =
  match fm with
  | Or(p,q) -> Or(rawdnf p,rawdnf q)
  | And(p,q) ->
        let conj = And(rawdnf p,rawdnf q) in
        distrib conj
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<(p /\ (q \/ r /\ s)) /\ (~p \/ ~q \/ ~s)>>;;

rawdnf fm;;
*)

(* ------------------------------------------------------------------------- *)
(* A version using a list representation.                                    *)
(* ------------------------------------------------------------------------- *)

(* Distribution on the list representation: applies union across s1 and s2   *)
(* through the allpairs helper (both defined elsewhere; presumably this      *)
(* yields the union of each member of s1 with each member of s2). The        *)
(* arguments are spelled out rather than partially applying allpairs, as     *)
(* the remarks on the right explain.                                         *)
let distrib s1 s2 = allpairs union s1 s2;;      (** Value restriction hell!  *)
                                                (** Need it for FOL formulas *)
(* Disjunctive form as a list of lists: the outer list stands for the        *)
(* disjunction, each inner list for a conjunction. Anything that is not a    *)
(* conjunction or disjunction is treated as a single literal.                *)
let rec purednf fm =
  match fm with
  | Or(p,q) -> union (purednf p) (purednf q)
  | And(p,q) -> distrib (purednf p) (purednf q)
  | _ -> [[fm]];;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
purednf fm;;
*)

(* ------------------------------------------------------------------------- *)
(* Filtering out noncontradictory disjuncts only.                            *)
(* ------------------------------------------------------------------------- *)

(* A literal is negative when its outermost connective is a negation.        *)
let negative lit =
  match lit with
  | Not _ -> true
  | _ -> false;;

(* A literal is positive when it is not negative.                            *)
let positive lit =
  if negative lit then false else true;;

(* Negate a literal: strip an outermost negation if there is one, and add    *)
(* one otherwise.                                                            *)
let negate lit =
  match lit with
  | Not p -> p
  | _ -> Not lit;;

(* A list of literals is contradictory when some positive literal is also    *)
(* the negation of one of the negative literals.                             *)
let contradictory lits =
  let pos,neg = partition positive lits in
  let negated = map negate neg in
  match intersect pos negated with
  | [] -> false
  | _ -> true;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
filter (non contradictory) (purednf fm);;
*)

(* ------------------------------------------------------------------------- *)
(* With subsumption checking, done very naively (quadratic).                 *)
(* ------------------------------------------------------------------------- *)

(* subsumes s1 s2 holds when psubset holds of s2 and s1, in that order       *)
(* (psubset is defined elsewhere; presumably the proper subset test).        *)
let subsumes s1 s2 =
  let smaller = s2 and larger = s1 in
  psubset smaller larger;;

(* Keep only those members cl of cls for which no member of cls satisfies    *)
(* subsumes cl. Every member is checked against the whole list.              *)
let subsume cls =
  let redundant cl = exists (subsumes cl) cls in
  filter (non redundant) cls;;

(* Simplified list-of-lists DNF. The constants are special cases: False      *)
(* gives no disjuncts and True gives a single empty conjunction. Otherwise   *)
(* the formula is put in negation normal form, converted with purednf, and   *)
(* contradictory and redundant disjuncts are removed.                        *)
let simpdnf fm =
  match fm with
  | False -> []
  | True -> [[]]
  | _ ->
        let disjuncts = purednf(nnf fm) in
        subsume (filter (non contradictory) disjuncts);;

(* ------------------------------------------------------------------------- *)
(* Mapping back to a formula.                                                *)
(* ------------------------------------------------------------------------- *)

(* Disjunctive normal form as a formula: the disjunction of the              *)
(* conjunctions listed by simpdnf.                                           *)
let dnf fm =
  let disjuncts = map list_conj (simpdnf fm) in
  list_disj disjuncts;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
dnf fm;;

tautology(Iff(fm,dnf fm));;
*)

(* ------------------------------------------------------------------------- *)
(* Conjunctive normal form (CNF) by duality.                                 *)
(* ------------------------------------------------------------------------- *)

(* List-of-lists CNF by duality: take the DNF of the negated formula and     *)
(* negate every literal in it, using the smap helper defined elsewhere.      *)
let purecnf fm =
  let negated_dnf = purednf(nnf(Not fm)) in
  smap (smap negate) negated_dnf;;

(* purecnf with contradictory members removed and then subsume applied.      *)
let simpcnf fm =
  let clauses = filter (non contradictory) (purecnf fm) in
  subsume clauses;;

(* Conjunctive normal form as a formula: the conjunction of the              *)
(* disjunctions listed by simpcnf.                                           *)
let cnf fm =
  let conjuncts = map list_disj (simpcnf fm) in
  list_conj conjuncts;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
cnf fm;;

cnf(Iff(fm,cnf fm));;
*)
(* ========================================================================= *)
(* Some propositional formulas to test, and functions to generate classes.   *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(*

forall tautology
 [<<p ==> q <=> ~q ==> ~p>>;
  <<~ ~p <=> p>>;
  <<~(p ==> q) ==> q ==> p>>;
  <<~p ==> q <=> ~q ==> p>>;
  <<(p \/ q ==> p \/ r) ==> p \/ (q ==> r)>>;
  <<p \/ ~p>>;
  <<p \/ ~ ~ ~p>>;
  <<((p ==> q) ==> p) ==> p>>;
  <<(p \/ q) /\ (~p \/ q) /\ (p \/ ~q) ==> ~(~q \/ ~q)>>;
  <<(q ==> r) /\ (r ==> p /\ q) /\ (p ==> q /\ r) ==> (p <=> q)>>;
  <<p <=> p>>;
  <<((p <=> q) <=> r) <=> (p <=> (q <=> r))>>;
  <<p \/ q /\ r <=> (p \/ q) /\ (p \/ r)>>;
  <<(p <=> q) <=> (q \/ ~p) /\ (~q \/ p)>>;
  <<p ==> q <=> ~p \/ q>>;
  <<(p ==> q) \/ (q ==> p)>>;
  <<p /\ (q ==> r) ==> s <=> (~p \/ q \/ s) /\ (~p \/ ~r \/ s)>>];;

(* ------------------------------------------------------------------------- *)
(* Some graph-colouring examples.                                            *)
(* ------------------------------------------------------------------------- *)

let fm =
 <<(a1 \/ a2 \/ a3) /\
   (b1 \/ b2 \/ b3) /\
   (c1 \/ c2 \/ c3) /\
   (d1 \/ d2 \/ d3) /\
   ~(a1 /\ a2) /\ ~(a1 /\ a3) /\ ~(a2 /\ a3) /\
   ~(b1 /\ b2) /\ ~(b1 /\ b3) /\ ~(b2 /\ b3) /\
   ~(c1 /\ c2) /\ ~(c1 /\ c3) /\ ~(c2 /\ c3) /\
   ~(d1 /\ d2) /\ ~(d1 /\ d3) /\ ~(d2 /\ d3) /\
   ~(a1 /\ d1) /\ ~(a2 /\ d2) /\ ~(a3 /\ d3) /\
   ~(b1 /\ d1) /\ ~(b2 /\ d2) /\ ~(b3 /\ d3) /\
   ~(c1 /\ d1) /\ ~(c2 /\ d2) /\ ~(c3 /\ d3) /\
   ~(a1 /\ b1) /\ ~(a2 /\ b2) /\ ~(a3 /\ b3) /\
   ~(b1 /\ c1) /\ ~(b2 /\ c2) /\ ~(b3 /\ c3) /\
   ~(c1 /\ a1) /\ ~(c2 /\ a2) /\ ~(c3 /\ a3)>>;;

satisfiable fm;;

*)

(* ------------------------------------------------------------------------- *)
(* Generate assertion equivalent to R(s,t) <= n for the Ramsey number R(s,t) *)
(* ------------------------------------------------------------------------- *)

(* Atom for a pair of integers given as a two-element list; its name is p_   *)
(* followed by the two numbers separated by an underscore. Fails on a list   *)
(* of any other length.                                                      *)
let var l =
  match l with
  | [m;n] ->
        let name = "p_"^(string_of_int m)^"_"^(string_of_int n) in
        Atom(P name)
  | _ -> failwith "var: expected 2-element list";;

(* Formula over the pair atoms of the vertices 1 to n: either some group     *)
(* built from s vertices has all of its pair atoms true, or some group       *)
(* built from t vertices has all of its pair atoms false. Groups are formed  *)
(* with the allsets helper defined elsewhere.                                *)
let ramsey s t n =
  let vertices = 1 -- n in
  let pair_groups size = map (allsets 2) (allsets size vertices) in
  let yesgrps = pair_groups s
  and nogrps = pair_groups t in
  let all_true = list_conj ** map var
  and all_false = list_conj ** map (fun p -> Not(var p)) in
  Or(list_disj (map all_true yesgrps),
     list_disj (map all_false nogrps));;

(* ------------------------------------------------------------------------- *)
(* Some currently tractable examples.                                        *)
(* ------------------------------------------------------------------------- *)

(*

ramsey 3 3 4;;

tautology(ramsey 3 3 5);;

tautology(ramsey 3 3 6);;

*)

(* ------------------------------------------------------------------------- *)
(* Half adder.                                                               *)
(* ------------------------------------------------------------------------- *)

(* Sum output of a half adder: x is equivalent to the negation of y.         *)
let halfsum x y =
  let not_y = Not y in
  Iff(x,not_y);;

(* Carry output of a half adder: the conjunction of the two inputs.          *)
let halfcarry x y =
  let both = (x,y) in
  And both;;

(* Half adder relation: s is equivalent to the half sum of x and y, and c    *)
(* to their half carry.                                                      *)
let ha x y s c =
  let sum_ok = Iff(s,halfsum x y)
  and carry_ok = Iff(c,halfcarry x y) in
  And(sum_ok,carry_ok);;

(* ------------------------------------------------------------------------- *)
(* Full adder.                                                               *)
(* ------------------------------------------------------------------------- *)

(* Carry output of a full adder: both x and y, or at least one of them       *)
(* together with z.                                                          *)
let carry x y z =
  let both = And(x,y)
  and either_with_z = And(Or(x,y),z) in
  Or(both,either_with_z);;

(* Sum output of a full adder: the half sum of x and y, half-summed with z.  *)
let sum x y z =
  let xy = halfsum x y in
  halfsum xy z;;

(* Full adder relation: s is equivalent to the sum of x, y and z, and c to   *)
(* their carry.                                                              *)
let fa x y z s c =
  let sum_ok = Iff(s,sum x y z)
  and carry_ok = Iff(c,carry x y z) in
  And(sum_ok,carry_ok);;

(* ------------------------------------------------------------------------- *)
(* Useful idiom.                                                             *)
(* ------------------------------------------------------------------------- *)

(* Conjunction of the formulas obtained by applying f to each element of l.  *)
let conjoin f l =
  let conjuncts = map f l in
  list_conj conjuncts;;

(* ------------------------------------------------------------------------- *)
(* n-bit ripple carry adder with carry c(0) propagated in and c(n) out.      *)
(* ------------------------------------------------------------------------- *)

(* Conjunction of n full adder stages, for positions 0 to n - 1. Stage i     *)
(* adds x i, y i and the carry c i, giving the sum out i and the next carry  *)
(* c (i + 1).                                                                *)
let ripplecarry x y c out n =
  let stage i = fa (x i) (y i) (c i) (out i) (c(i + 1)) in
  conjoin stage (0 -- (n - 1));;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* Atom whose name is s followed by an underscore and the decimal form of i. *)
let mk_index s i =
  let name = s^"_"^(string_of_int i) in
  Atom(P name);;

(*

let [x; y; out; c] = map mk_index ["X"; "Y"; "OUT"; "C"];;

ripplecarry x y c out 2;;

*)

(* ------------------------------------------------------------------------- *)
(* Special case with 0 instead of c(0).                                      *)
(* ------------------------------------------------------------------------- *)

(* Ripple carry adder with the carry at position 0 fixed to False, then      *)
(* simplified with psimplify.                                                *)
let ripplecarry0 x y c out n =
  let carry_from_false i = if i = 0 then False else c i in
  psimplify (ripplecarry x y carry_from_false out n);;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*

ripplecarry0 x y c out 2;;

*)

(* ------------------------------------------------------------------------- *)
(* Carry-select adder                                                        *)
(* ------------------------------------------------------------------------- *)

(* Ripple carry adder with the carry at position 0 fixed to True, then       *)
(* simplified with psimplify.                                                *)
let ripplecarry1 x y c out n =
  let carry_from_true i = if i = 0 then True else c i in
  psimplify (ripplecarry x y carry_from_true out n);;

(* Multiplexer: in0 when sel is false, in1 when sel is true.                 *)
let mux sel in0 in1 =
  let when_clear = And(Not sel,in0)
  and when_set = And(sel,in1) in
  Or(when_clear,when_set);;

(* Shift the index of the indexed family x by n: position i of the result    *)
(* is position n + i of x.                                                   *)
let offset n x = fun i -> x (n + i);;

(* Carry-select adder on n bits in blocks of k bits. For each block, two     *)
(* ripple carry adders are described, one with carry-in False (c0, s0) and   *)
(* one with carry-in True (c1, s1); the block carry-out and sum bits are     *)
(* chosen between them by a multiplexer on c 0. The remaining bits are       *)
(* handled recursively with every indexed family shifted by k.               *)
let rec carryselect x y c0 c1 s0 s1 c s n k =
  let k' = min n k in
  let select in0 in1 = mux (c 0) in0 in1 in
  let adders =
    And(ripplecarry0 x y c0 s0 k',ripplecarry1 x y c1 s1 k') in
  let carry_out = Iff(c k',select (c0 k') (c1 k')) in
  let sums =
    conjoin (fun i -> Iff(s i,select (s0 i) (s1 i))) (0 -- (k' - 1)) in
  let fm = And(adders,And(carry_out,sums)) in
  (* k' < k means fewer than k bits were left, so this was the last block.   *)
  if k' < k then fm
  else
    let shift f = offset k f in
    And(fm,carryselect
              (shift x) (shift y) (shift c0) (shift c1)
              (shift s0) (shift s1) (shift c) (shift s)
              (n - k) k);;

(* ------------------------------------------------------------------------- *)
(* Equivalence problems for carry-select vs ripple carry adders.             *)
(* ------------------------------------------------------------------------- *)

(* Equivalence problem for adders: if a carry-select adder (block size k)    *)
(* with carry-in c 0 false and a ripple carry adder with carry-in False are  *)
(* both described on the same n-bit inputs x and y, then their carry-outs    *)
(* and their n sum bits agree. The indexed families are bound one by one     *)
(* rather than through a refutable list pattern, which the compiler flags    *)
(* as a non-exhaustive match.                                                *)
let mk_adder_test n k =
  let x = mk_index "x" and y = mk_index "y"
  and c = mk_index "c" and s = mk_index "s"
  and c0 = mk_index "c0" and s0 = mk_index "s0"
  and c1 = mk_index "c1" and s1 = mk_index "s1"
  and c2 = mk_index "c2" and s2 = mk_index "s2" in
  Imp(And(And(carryselect x y c0 c1 s0 s1 c s n k,Not(c 0)),
          ripplecarry0 x y c2 s2 n),
      And(Iff(c n,c2 n),
          conjoin (fun i -> Iff(s i,s2 i)) (0 -- (n - 1))));;

(* ------------------------------------------------------------------------- *)
(* Ripple carry stage that separates off the final result.                   *)
(*                                                                           *)
(*       UUUUUUUUUUUUUUUUUUUU  (u)                                           *)
(*    +  VVVVVVVVVVVVVVVVVVVV  (v)                                           *)
(*                                                                           *)
(*    = WWWWWWWWWWWWWWWWWWWW   (w)                                           *)
(*    +                     Z  (z)                                           *)
(* ------------------------------------------------------------------------- *)

(* Ripple carry stage with carry-in False in which the lowest sum bit goes   *)
(* to z and sum bit i, for i above 0, goes to w (i - 1). The carry at        *)
(* position i is c (i + 1), except that the carry at position n is routed    *)
(* to w (n - 1).                                                             *)
let rippleshift u v c z w n =
  let carries i = if i = n then w(n - 1) else c(i + 1)
  and sums i = if i = 0 then z else w(i - 1) in
  ripplecarry0 u v carries sums n;;

(* ------------------------------------------------------------------------- *)
(* Naive multiplier based on repeated ripple carry.                          *)
(* ------------------------------------------------------------------------- *)

(* Multiplier built from repeated rippleshift stages. x i j is the partial   *)
(* product bit in row i and column j, u and v are families of intermediate   *)
(* atoms, and out holds the result bits. The case n = 1 is a single bit      *)
(* product; n = 2 needs only the first stage, whose result is tied directly  *)
(* to the top two output bits.                                               *)
let mult_rip x u v out n =
  let low_bit = Iff(out 0,x 0 0) in
  if n = 1 then And(low_bit,Not(out 1)) else
  (* Row 0 shifted down one place, with False in the top position.           *)
  let shifted_row0 i = if i = n - 1 then False else x 0 (i + 1) in
  let first_stage =
    rippleshift shifted_row0 (x 1) (v 2) (out 1) (u 2) n in
  let later_stages =
    if n = 2 then And(Iff(out 2,u 2 0),Iff(out 3,u 2 1)) else
    let stage k =
      (* The last stage writes straight into the upper output bits.          *)
      let w = if k = n - 1 then (fun i -> out(n + i)) else u(k + 1) in
      rippleshift (u k) (x k) (v(k + 1)) (out k) w n in
    conjoin stage (2 -- (n - 1)) in
  psimplify (And(low_bit,And(first_stage,later_stages)));;

(* ------------------------------------------------------------------------- *)
(* One stage in a sequential multiplier based on CSAs.                       *)
(*                                                                           *)
(*        UUUUUUUUUUUUUUU         (u)                                        *)
(*        VVVVVVVVVVVVVVV         (v)                                        *)
(*       XXXXXXXXXXXXXXX          (x)                                        *)
(*     ------------------------------------------------                      *)
(*       WWWWWWWWWWWWWWW          (w)                                        *)
(*       YYYYYYYYYYYYYYY          (y)                                        *)
(*                      Z         (z)                                        *)
(* ------------------------------------------------------------------------- *)

(* One carry-save stage: a half adder on the lowest bits of u and v gives z  *)
(* and w 0; full adders combine u (i + 1), v (i + 1) and x i into y i and    *)
(* w (i + 1) for i from 0 to n - 2; the top bit of y equals the top bit of   *)
(* x.                                                                        *)
let csastage u v x z w y n =
  let low = ha (u 0) (v 0) z (w 0) in
  let column i = fa (u(i + 1)) (v(i + 1)) (x i) (y i) (w(i + 1)) in
  let middle = conjoin column (0 -- (n - 2)) in
  let top = Iff(y(n - 1),x(n - 1)) in
  And(low,And(middle,top));;

(* ------------------------------------------------------------------------- *)
(* CSA-based multiplier only using ripple carry for last iteration.          *)
(* ------------------------------------------------------------------------- *)

(* Initial carry-save step combining rows 0, 1 and 2 of the partial          *)
(* products x into the families u 3 and v 3, and fixing the output bits      *)
(* out 0 and out 1.                                                          *)
let csa_init x u v out n =
  let bit0 = Iff(out 0,x 0 0) in
  let bit1 = ha (x 0 1) (x 1 0) (out 1) (u 3 0) in
  let column i =
    fa (x 0 (i + 2)) (x 1 (i + 1)) (x 2 i) (v 3 i) (u 3 (i + 1)) in
  let middle = conjoin column (0 -- (n - 3)) in
  let upper =
    ha (x 1 (n - 1)) (x 2 (n - 2)) (v 3 (n - 2)) (u 3 (n - 1)) in
  let top = Iff(v 3 (n - 1),x 2 (n - 1)) in
  And(bit0,And(bit1,And(middle,And(upper,top))));;

(* Whole multiplier: the initial stage, then one [csastage] for each        *)
(* i = 3 .. n-1, then a final [rippleshift] producing out(n-1) and the      *)
(* outputs out(n+i).                                                        *)
let mult_csa x u v out n =
  let stage i =
    csastage (u i) (v i) (x i) (out(i - 1)) (u(i + 1)) (v(i + 1)) n in
  let first = csa_init x u v out n in
  let stages = conjoin stage (3 -- (n - 1)) in
  let final =
    rippleshift (u n) (v n) (v(n + 1))
                (out(n - 1)) (fun i -> out(n + i)) n in
  And(first,And(stages,final));;

(* ------------------------------------------------------------------------- *)
(* Equivalence for ripple vs CSA sequential multiplier.                      *)
(* ------------------------------------------------------------------------- *)

(* Atom whose name is x, i and j joined by underscores, e.g. u_3_0.         *)
let mk_index2 x i j =
  let name = String.concat "_" [x; string_of_int i; string_of_int j] in
  Atom(P name);;

(* Formula stating that if both multiplier descriptions hold of the same    *)
(* inputs then their outputs agree on bits 0 .. n-1.                        *)
let mk_mult_test n =
  let [x; y; out; out'] = map mk_index ["x"; "y"; "**"; "p"] in
  let [u; v; u'; v'] = map mk_index2 ["u"; "v"; "w"; "z"] in
  let m i j = And(x i,y j) in
  let ripple = mult_rip m u v out n
  and carry_save = mult_csa m u' v' out' n in
  let same_bit i = Iff(out i,out' i) in
  Imp(And(ripple,carry_save),conjoin same_bit (0 -- (n - 1)));;

(* ------------------------------------------------------------------------- *)
(* Primality examples (using CSAs since they are slightly shorter).          *)
(* For large examples, should use "num" instead of "int" in these functions. *)
(* ------------------------------------------------------------------------- *)

(* Number of halvings (integer division by 2) needed to bring x to 0.       *)
let rec bitlength x =
  match x with
    0 -> 0
  | _ -> 1 + bitlength (x / 2);;

(* Tests bit number n of x (bit 0 is the least significant): halve x n      *)
(* times and check whether the remainder modulo 2 is 1.                     *)
let rec bit n x =
  match n with
    0 -> x mod 2 = 1
  | _ -> bit (n - 1) (x / 2);;

(* Conjunction over i = 0 .. n-1 asserting x i when bit i of m is set and   *)
(* its negation otherwise.                                                  *)
let congruent_to x m n =
  let literal i = if bit i m then x i else Not(x i) in
  conjoin literal (0 -- (n - 1));;

(* Negation of: the ripple multiplier on (n-1)-bit inputs holds and its     *)
(* output encodes p, where n is the bit length of p.                        *)
let prime p =
  let [x; y; out] = map mk_index ["x"; "y"; "out"] in
  let [u; v] = map mk_index2 ["u"; "v"] in
  let m i j = And(x i,y j) in
  let n = bitlength p in
  let product = mult_rip m u v out (n - 1)
  and equals_p = congruent_to out p (max n (2 * n - 2)) in
  Not(And(product,equals_p));;
(* ========================================================================= *)
(* Definitional CNF.                                                         *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Variant of NNF without splitting equivalences.                            *)
(* ------------------------------------------------------------------------- *)

(* Pushes negations inwards and rewrites implications as disjunctions, but  *)
(* keeps Iff as a connective (a negated Iff negates its right-hand side).   *)
let rec nenf = function
  | And(p,q) -> And(nenf p,nenf q)
  | Or(p,q) -> Or(nenf p,nenf q)
  | Imp(p,q) -> Or(nenf(Not p),nenf q)
  | Iff(p,q) -> Iff(nenf p,nenf q)
  | Not(Not p) -> nenf p
  | Not(And(p,q)) -> Or(nenf(Not p),nenf(Not q))
  | Not(Or(p,q)) -> And(nenf(Not p),nenf(Not q))
  | Not(Imp(p,q)) -> And(nenf p,nenf(Not q))
  | Not(Iff(p,q)) -> Iff(nenf p,nenf(Not q))
  | fm -> fm;;

(* ------------------------------------------------------------------------- *)
(* Make a stylized variable and update the index.                            *)
(* ------------------------------------------------------------------------- *)

(* Returns the atom named p_ followed by n, paired with the next index.     *)
let mkprop n =
  let name = "p_"^(string_of_num n) in
  (Atom(P name),n +/ Int 1);;

(* ------------------------------------------------------------------------- *)
(* Make n large enough that "v_m" won't clash with s for any m >= n          *)
(* ------------------------------------------------------------------------- *)

(* [max_varindex pfx s n] returns the larger of n and the numeric suffix of *)
(* s when s is pfx followed by at least one character and every suffix      *)
(* character satisfies [numeric]; otherwise it returns n unchanged.         *)
let max_varindex pfx =
  let m = String.length pfx in
  fun s n ->
    let l = String.length s in
    (* Uses || instead of the deprecated [or] keyword operator; both        *)
    (* short-circuit, so String.sub only runs when s is longer than pfx.    *)
    if l <= m || String.sub s 0 m <> pfx then n else
    let s' = String.sub s m (l - m) in
    if forall numeric (explode s') then max_num n (num_of_string s')
    else n;;

(* ------------------------------------------------------------------------- *)
(* Basic definitional CNF procedure.                                         *)
(* ------------------------------------------------------------------------- *)

(* [maincnf] works on a triple (formula, definitions so far, next index).   *)
(* Binary And/Or/Iff nodes are handed to [defstep]; anything else is        *)
(* returned unchanged.                                                      *)
let rec maincnf (fm,_,_ as trip) =
  match fm with
    And(p,q) -> defstep (fun (a,b) -> And(a,b)) (p,q) trip
  | Or(p,q) -> defstep (fun (a,b) -> Or(a,b)) (p,q) trip
  | Iff(p,q) -> defstep (fun (a,b) -> Iff(a,b)) (p,q) trip
  | _ -> trip

(* [defstep] processes both operands, rebuilds the node with [op], and      *)
(* either reuses the variable already defined for that formula or creates   *)
(* a fresh one with [mkprop] and records the definition v <=> fm'.          *)
and defstep op (p,q) (_,defs,n) =
  let fm1,defs1,n1 = maincnf (p,defs,n) in
  let fm2,defs2,n2 = maincnf (q,defs1,n1) in
  let fm' = op(fm1,fm2) in
  let known = try Some(fst(apply defs2 fm')) with Failure _ -> None in
  match known with
    Some v -> (v,defs2,n2)
  | None ->
      let v,n3 = mkprop n2 in
      (v,(fm'|->(v,Iff(v,fm'))) defs2,n3);;

(* Definitional CNF: simplify and apply [nenf], start the fresh-variable    *)
(* index one past the highest existing p_ index, run [maincnf], then        *)
(* combine the CNF of every definition with that of the final formula.      *)
let defcnf fm =
  let fm' = nenf(psimplify fm) in
  let bump = max_varindex "p_" in
  let n = Int 1 +/ overatoms (fun a acc -> bump (pname a) acc) fm' (Int 0) in
  let (fm'',defs,_) = maincnf (fm',undefined,n) in
  let deflist = map (fun (_,(_,def)) -> def) (funset defs) in
  let subcnfs = itlist (fun d acc -> simpcnf d @ acc) deflist (simpcnf fm'') in
  list_conj (map list_disj (setify subcnfs));;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<(p \/ (q /\ ~r)) /\ s>>;;

defcnf fm;;

cnf fm;;

cnf <<p <=> (q <=> r)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Version tweaked to exploit initial structure.                             *)
(* ------------------------------------------------------------------------- *)

(* Applies [sfn] to both operands, threading the definitions and index      *)
(* through, and rebuilds the node with [op] without adding a definition.    *)
let subcnf sfn op (p,q) (_,defs,n) =
  let left,defs_l,n_l = sfn(p,defs,n) in
  let right,defs_r,n_r = sfn(q,defs_l,n_l) in
  (op(left,right),defs_r,n_r);;

(* Descends through top-level disjunctions without defining them; any       *)
(* other formula goes to [maincnf].                                         *)
let rec orcnf (fm,_,_ as trip) =
  match fm with
    Or(l,r) -> subcnf orcnf (fun (a,b) -> Or(a,b)) (l,r) trip
  | _ -> maincnf trip;;

(* Descends through top-level conjunctions without defining them; any       *)
(* other formula goes to [orcnf].                                           *)
let rec andcnf (fm,_,_ as trip) =
  match fm with
    And(l,r) -> subcnf andcnf (fun (a,b) -> And(a,b)) (l,r) trip
  | _ -> orcnf trip;;

(* Clause-set form of definitional CNF, entering through [andcnf] so that   *)
(* the initial conjunctive/disjunctive structure is kept.                   *)
let defcnfs fm =
  let fm' = nenf(psimplify fm) in
  let bump = max_varindex "p_" in
  let n = Int 1 +/ overatoms (fun a acc -> bump (pname a) acc) fm' (Int 0) in
  let (fm'',defs,_) = andcnf (fm',undefined,n) in
  let deflist = map (fun (_,(_,def)) -> def) (funset defs) in
  let clauses = itlist (fun d acc -> simpcnf d @ acc) deflist (simpcnf fm'') in
  setify clauses;;

(* Formula form of [defcnfs]: a conjunction of disjunctions.                *)
let defcnf fm =
  let clauses = defcnfs fm in
  list_conj (map list_disj clauses);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
defcnf fm;;
*)

(* ------------------------------------------------------------------------- *)
(* Version using only implication where possible.                            *)
(* ------------------------------------------------------------------------- *)

(* Like the earlier definitional step, but definitions are keyed by the     *)
(* pair (pos, formula) and, when [pos] is true, the recorded definition is  *)
(* the implication v ==> fm' rather than the equivalence v <=> fm'.         *)
let defstep pos sfn op (p,q) (_,defs,n) =
  let fm1,defs1,n1 = sfn (p,defs,n) in
  let fm2,defs2,n2 = sfn (q,defs1,n1) in
  let fm' = op(fm1,fm2) in
  let key = (pos,fm') in
  let known = try Some(fst(apply defs2 key)) with Failure _ -> None in
  match known with
    Some v -> (v,defs2,n2)
  | None ->
      let (v,n3) = mkprop n2 in
      let def = if pos then Imp(v,fm') else Iff(v,fm') in
      (v,(key |-> (v,def)) defs2,n3);;

(* Polarity-aware main loop.  Operands of And/Or are processed with the     *)
(* same [pos] flag; operands of Iff are always processed with false.        *)
let rec maincnf pos (fm,_,_ as trip) =
  let step sub op args = defstep pos sub op args trip in
  match fm with
    And(p,q) -> step (maincnf pos) (fun (a,b) -> And(a,b)) (p,q)
  | Or(p,q) -> step (maincnf pos) (fun (a,b) -> Or(a,b)) (p,q)
  | Iff(p,q) -> step (maincnf false) (fun (a,b) -> Iff(a,b)) (p,q)
  | _ -> trip;;

(* Polarity-aware version: keeps top-level disjunctions, otherwise defers   *)
(* to [maincnf pos].                                                        *)
let rec orcnf pos (fm,_,_ as trip) =
  match fm with
    Or(l,r) -> subcnf (orcnf pos) (fun (a,b) -> Or(a,b)) (l,r) trip
  | _ -> maincnf pos trip;;

(* Polarity-aware version: keeps top-level conjunctions, otherwise defers   *)
(* to [orcnf pos].                                                          *)
let rec andcnf pos (fm,_,_ as trip) =
  match fm with
    And(l,r) -> subcnf (andcnf pos) (fun (a,b) -> And(a,b)) (l,r) trip
  | _ -> orcnf pos trip;;

(* Clause-set definitional CNF; [imps] is passed to [andcnf] as the         *)
(* polarity flag that selects implicational definitions.                    *)
let defcnfs imps fm =
  let fm' = nenf(psimplify fm) in
  let bump = max_varindex "p_" in
  let n = Int 1 +/ overatoms (fun a acc -> bump (pname a) acc) fm' (Int 0) in
  let (fm'',defs,_) = andcnf imps (fm',undefined,n) in
  let deflist = map (fun (_,(_,def)) -> def) (funset defs) in
  let clauses = itlist (fun d acc -> simpcnf d @ acc) deflist (simpcnf fm'') in
  setify clauses;;

(* Formula form of [defcnfs imps].                                          *)
let defcnf imps fm =
  let clauses = defcnfs imps fm in
  list_conj (map list_disj clauses);;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
defcnf false fm;;

defcnf true fm;;
*)

(* ------------------------------------------------------------------------- *)
(* Version that guarantees 3-CNF.                                            *)
(* ------------------------------------------------------------------------- *)

(* Keeps top-level conjunctions but sends everything else straight to       *)
(* [maincnf pos], skipping [orcnf], so disjunctions get definitions too.    *)
let rec andcnf3 pos (fm,_,_ as trip) =
  match fm with
    And(l,r) -> subcnf (andcnf3 pos) (fun (a,b) -> And(a,b)) (l,r) trip
  | _ -> maincnf pos trip;;

(* Same pipeline as [defcnfs] but entering through [andcnf3].               *)
let defcnf3s imps fm =
  let fm' = nenf(psimplify fm) in
  let bump = max_varindex "p_" in
  let n = Int 1 +/ overatoms (fun a acc -> bump (pname a) acc) fm' (Int 0) in
  let (fm'',defs,_) = andcnf3 imps (fm',undefined,n) in
  let deflist = map (fun (_,(_,def)) -> def) (funset defs) in
  let clauses = itlist (fun d acc -> simpcnf d @ acc) deflist (simpcnf fm'') in
  setify clauses;;

(* Formula form of [defcnf3s imps].                                         *)
let defcnf3 imps fm =
  let clauses = defcnf3s imps fm in
  list_conj (map list_disj clauses);;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<(p \/ q \/ r \/ s) /\ ~p /\ (~p \/ q)>>;;

defcnf true fm;;

defcnf3 true fm;;
*)
(* ========================================================================= *)
(* The Davis-Putnam and Davis-Putnam-Loveland-Logemann procedures.           *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Clause set for fm, using equivalence (not implication) definitions.      *)
let clausal fm =
  let imps = false in
  defcnfs imps fm;;

(* ------------------------------------------------------------------------- *)
(* Examples of clausal form.                                                 *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<p /\ (q <=> (~p <=> r))>>;;

clausal fm;;
*)

(* ------------------------------------------------------------------------- *)
(* The DP procedure.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Unit propagation: pick a one-literal clause, drop every clause that      *)
(* contains that literal, and delete its negation from the rest.  The       *)
(* callers rely on a Failure when there is no unit clause (presumably       *)
(* raised by [find] -- confirm against its definition).                     *)
let one_literal_rule clauses =
  let unit_clause = find (fun cl -> length cl = 1) clauses in
  let lit = hd unit_clause in
  let lit' = negate lit in
  let survivors = filter (fun cl -> not (mem lit cl)) clauses in
  smap (fun cl -> subtract cl [lit']) survivors;;

(* Pure-literal rule: find literals that occur with only one polarity and   *)
(* remove every clause containing one of them.  Fails when there is no      *)
(* such literal.                                                            *)
let affirmative_negative_rule clauses =
  let all_lits = itlist union clauses [] in
  let negs,poss = partition negative all_lits in
  let negs_as_pos = smap negate negs in
  let both = intersect poss negs_as_pos in
  let pure_pos = subtract poss both in
  let pure_neg = smap negate (subtract negs_as_pos both) in
  let elim = union pure_pos pure_neg in
  if elim <> [] then filter (fun cl -> intersect cl elim = []) clauses
  else failwith "affirmative_negative_rule";;

(* Pairs l with m*n - m - n, where m and n count the clauses containing l   *)
(* and its negation: the change in clause count if l is resolved away.      *)
let find_blowup cls l =
  let occurrences lit = length(filter (mem lit) cls) in
  let m = occurrences l in
  let n = occurrences (negate l) in
  (m * n - m - n,l);;

(* Eliminates the positive literal with the smallest blowup by resolving    *)
(* every clause containing it against every clause containing its           *)
(* negation, discarding contradictory (trivial) resolvents.                 *)
let resolution_rule clauses =
  let pvs = filter positive (itlist union clauses []) in
  let lblows = map (find_blowup clauses) pvs in
  let best = end_itlist min (map fst lblows) in
  let p = assoc best lblows in
  let p' = negate p in
  let pos,notpos = partition (mem p) clauses in
  let neg,none = partition (mem p') notpos in
  let strip lit cls = smap (filter (fun l -> l <> lit)) cls in
  let res0 = allpairs union (strip p pos) (strip p' neg) in
  union none (filter (non contradictory) res0);;

(* ------------------------------------------------------------------------- *)
(* Overall procedure.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Davis-Putnam loop on a clause set.  An empty set is satisfiable and a    *)
(* set containing the empty clause is not.  Otherwise the rules are tried   *)
(* in order: unit propagation, then the pure-literal rule, then             *)
(* resolution.  Each [try] wraps the recursive call as well as the rule     *)
(* itself, so a Failure raised anywhere below also falls through to the     *)
(* next rule.                                                               *)
let rec dp clauses =
  if clauses = [] then true
  else if mem [] clauses then false else
  try dp(one_literal_rule clauses)
  with Failure _ -> try
      dp(affirmative_negative_rule clauses)
  with Failure _ ->
      dp(resolution_rule clauses);;

(* ------------------------------------------------------------------------- *)
(* Davis-Putnam satisfiability tester and tautology checker.                 *)
(* ------------------------------------------------------------------------- *)

(* Satisfiability test: convert to clauses, then run [dp].                  *)
let dpsat fm =
  let clauses = clausal fm in
  dp clauses;;

(* fm is a tautology exactly when its negation is unsatisfiable.            *)
let dptaut fm =
  let negation_sat = dpsat(Not fm) in
  not negation_sat;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
tautology(prime 11);;

dptaut(prime 11);;
*)

(* ------------------------------------------------------------------------- *)
(* The same thing but with the DPLL procedure.                               *)
(* ------------------------------------------------------------------------- *)

(* Pairs l with the number of clauses containing l plus the number          *)
(* containing its negation.                                                 *)
let find_count cls l =
  let occurrences lit = length(filter (mem lit) cls) in
  let m = occurrences l in
  let n = occurrences (negate l) in
  (m + n,l);;

(* DPLL loop on a clause set.  Unit propagation and the pure-literal rule   *)
(* are tried first, as in [dp].  When neither applies, the positive         *)
(* literal with the highest occurrence count is chosen and a case split is  *)
(* made by adding it, then its negation, as a unit clause.  Each [try]      *)
(* wraps the recursive call, so a Failure raised below falls through too.   *)
let rec dpll clauses =
  if clauses = [] then true
  else if mem [] clauses then false else
  try dpll(one_literal_rule clauses)
  with Failure _ -> try
      dpll(affirmative_negative_rule clauses)
  with Failure _ ->
    let pvs = filter positive (itlist union clauses []) in
    let lcounts = map (find_count clauses) pvs in
    let p = assoc (end_itlist max (map fst lcounts)) lcounts in
    (* || replaces the deprecated [or]; the second branch is still only     *)
    (* explored when the first one fails.                                   *)
    dpll (insert [p] clauses) ||
    dpll (insert [negate p] clauses);;

(* Satisfiability test: convert to clauses, then run [dpll].                *)
let dpllsat fm =
  let clauses = clausal fm in
  dpll clauses;;

(* fm is a tautology exactly when its negation is unsatisfiable.            *)
let dplltaut fm =
  let negation_sat = dpllsat(Not fm) in
  not negation_sat;;

(* ------------------------------------------------------------------------- *)
(* The same example.                                                         *)
(* ------------------------------------------------------------------------- *)

(* NOTE(review): unlike the other examples in this file, this call is not   *)
(* wrapped in a comment, so it is evaluated whenever the file is loaded.    *)
(* Confirm whether that is intended.                                        *)
dplltaut(prime 11);;
(* ========================================================================= *)
(* Simple implementation of Stalmarck's algorithm.                           *)
(*                                                                           *)
(* NB! This algorithm is patented for commercial use (not that a toy version *)
(* like this would actually be useful in practice). See US patent 5 276 897, *)
(* Swedish patent 467 076 and European patent 0403 454 for example.          *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Triplet transformation, using functions defined earlier.                  *)
(* ------------------------------------------------------------------------- *)

(* Runs the equivalence-based [maincnf] on the simplified formula and       *)
(* returns the resulting top-level formula with the list of definitions.    *)
let triplicate fm =
  let fm' = nenf(psimplify fm) in
  let bump = max_varindex "p_" in
  let n = Int 1 +/ overatoms (fun a acc -> bump (pname a) acc) fm' (Int 0) in
  let (p,defs,_) = maincnf false (fm',undefined,n) in
  (p,map (fun (_,(_,def)) -> def) (funset defs));;

(* ------------------------------------------------------------------------- *)
(* Automatically generate triggering rules to save writing them out.         *)
(* ------------------------------------------------------------------------- *)

(* Strips the negation from a negative literal; positive ones are kept.     *)
let atom lit =
  match negative lit with
    true -> negate lit
  | false -> lit;;

(* Normalises an equation between literals: the side with the larger        *)
(* underlying atom goes first, and that first side is made positive by      *)
(* negating both sides if necessary.                                        *)
let rec align (p,q as pair) =
  if atom p < atom q then align(q,p)
  else if not(negative p) then pair
  else (negate p,negate q);;

(* Equates p with q and also the negation of p with the negation of q.      *)
let equate2 (p,q) eqv =
  let eqv_pos = equate (p,q) eqv in
  equate (negate p,negate q) eqv_pos;;

(* Filters out equations already implied by rel; each equation kept is      *)
(* added to rel before the remaining ones are examined.                     *)
let rec irredundant rel eqs =
  match eqs with
    [] -> []
  | (p,q)::oth ->
      if canonize rel p <> canonize rel q
      then insert (p,q) (irredundant (equate2 (p,q) rel) oth)
      else irredundant rel oth;;

(* Equations from eqs that follow tautologically from fm together with the  *)
(* assumed equation peq, with redundant ones removed.                       *)
let consequences peq fm eqs =
  let (p,q) = peq in
  let hyp = And(Iff(p,q),fm) in
  let follows (r,s) = tautology(Imp(hyp,Iff(r,s))) in
  irredundant (equate2 peq unequal) (filter follows eqs);;

(* For every aligned equation between literals over the atoms of fm (plus   *)
(* True) with distinct underlying atoms, computes its consequences and      *)
(* keeps the equations that have at least one.                              *)
let triggers fm =
  let poslits = insert True (map (fun p -> Atom p) (atoms fm)) in
  let lits = union poslits (map (fun p -> Not p) poslits) in
  let distinct_atoms (p,q) = atom p <> atom q in
  let candidates = filter distinct_atoms (allpairs (fun p q -> p,q) lits lits) in
  let eqs = setify(map align candidates) in
  let with_conseqs eq = (eq,consequences eq fm eqs) in
  filter (fun (_,c) -> c <> []) (map with_conseqs eqs);;

(* ------------------------------------------------------------------------- *)
(* An example.                                                               *)
(* ------------------------------------------------------------------------- *)

(*
triggers <<p <=> (q /\ r)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Precompute and instantiate triggers for standard triplets.                *)
(* ------------------------------------------------------------------------- *)

(* Removes one outermost double negation, if present.                       *)
let ddnegate = function
  | Not(Not p) -> p
  | fm -> fm;;

(* [trigger] maps a triplet x <=> (y op z) to its list of triggering rules. *)
(* The rules for the four connectives are computed once, on the generic     *)
(* variables p, q, r, when this definition is evaluated; each call then     *)
(* only instantiates them.  The final match is deliberately partial: any    *)
(* formula that is not one of the four triplet shapes raises Match_failure. *)
let trigger =
  let [trig_and; trig_or; trig_imp; trig_iff] =
    map triggers
      [<<p <=> q /\ r>>; <<p <=> q \/ r>>;
       <<p <=> (q ==> r)>>; <<p <=> (q <=> r)>>]
  and p = <<p>> and q = <<q>> and r = <<r>> in
  (* Substitute x, y, z for p, q, r and strip any double negation created. *)
  let inst_fn [x;y;z] =
    let subfn = fpf [P"p" |-> x; P"q" |-> y; P"r" |-> z] in
    ddnegate ** propsubst subfn in
  (* Instantiate one equation, then a whole (equation, consequences) rule. *)
  let inst2_fn i (p,q) = align(inst_fn i p,inst_fn i q) in
  let instn_fn i (a,c) = inst2_fn i a,map (inst2_fn i) c in
  let inst_trigger = map ** instn_fn in
  function (Iff(x,And(y,z))) -> inst_trigger [x;y;z] trig_and
         | (Iff(x,Or(y,z))) -> inst_trigger [x;y;z] trig_or
         | (Iff(x,Imp(y,z))) -> inst_trigger [x;y;z] trig_imp
         | (Iff(x,Iff(y,z))) -> inst_trigger [x;y;z] trig_iff;;

(* ------------------------------------------------------------------------- *)
(* Finding variables in triggers, prioritized by fecundity estimate.         *)
(* ------------------------------------------------------------------------- *)

(* Adds a weight to the count of every variable in the trigger equation:    *)
(* the number of consequences, doubled when the atom of q is True.  The     *)
(* previous count is always read from the incoming f, as before.            *)
let fecundity ((p,q),conseqs) f =
  let weight = (if atom q = True then 2 else 1) * length conseqs in
  let bump x = (x |-> weight + tryapplyd f x 0) in
  itlist bump (union (atoms p) (atoms q)) f;;

(* Variables occurring in the triggers, as atoms, sorted with the ordering  *)
(* m <= n on their accumulated fecundity counts.                            *)
let variable_list trigs =
  let fec = itlist fecundity trigs undefined in
  let by_count (_,n) (_,m) = m <= n in
  map (fun (p,_) -> Atom p) (sort by_count (funset fec));;

(* ------------------------------------------------------------------------- *)
(* Compute a function mapping each variable/true to relevant triggers.       *)
(* ------------------------------------------------------------------------- *)

(* Adds trg to the triggers recorded for p in f.                            *)
let insert_relevant p trg f =
  let existing = tryapplyl f p in
  (p |-> insert trg existing) f;;

(* Records a trigger under both sides of its equation.                      *)
let insert_relevant2 ((p,q),_ as trg) f =
  let with_q = insert_relevant q trg f in
  insert_relevant p trg with_q;;

(* Returns the prioritised variable list together with the function from    *)
(* each equation side to the triggers that mention it.                      *)
let relevance trigs =
  (variable_list trigs,itlist insert_relevant2 trigs undefined);;

(* ------------------------------------------------------------------------- *)
(* Merging of equiv classes and relevancies.                                 *)
(* ------------------------------------------------------------------------- *)

(* Merges the classes of p0 and q0 (and of their negations) in the          *)
(* equivalence eqv, updating the relevance function rfn to match.  Returns  *)
(* the new equations to process together with the updated pair.  If p0 and  *)
(* q0 are already equivalent, nothing changes and no equations result.      *)
let equatecons (p0,q0) (eqv,rfn as erf) =
  let p = canonize eqv p0
  and q = canonize eqv q0 in
  if p = q then [],erf else
  let p' = canonize eqv (negate p0)
  and q' = canonize eqv (negate q0) in
  (* Relevant triggers are looked up under the OLD canonical elements ...   *)
  let eqv' = equate2(p,q) eqv
  and sp_pos = tryapplyl rfn p
  and sp_neg = tryapplyl rfn p'
  and sq_pos = tryapplyl rfn q
  and sq_neg = tryapplyl rfn q' in
  (* ... and their unions are stored under the NEW canonical elements.      *)
  let rfn' = itlist identity
    [canonize eqv' p |-> union sp_pos sq_pos;
     canonize eqv' p' |-> union sp_neg sq_neg] rfn in
  (* A trigger relevant to both merged classes contributes its consequences. *)
  let nw = union (intersect sp_pos sq_pos) (intersect sp_neg sq_neg) in
  itlist (union ** snd) nw [],(eqv',rfn');;

(* ------------------------------------------------------------------------- *)
(* Zero-saturation given an equivalence/relevance and new assignments.       *)
(* ------------------------------------------------------------------------- *)

(* Processes pending equations one at a time, adding whatever new           *)
(* equations each merge yields, until none are left.                        *)
let rec zero_saturate erf = function
  | [] -> erf
  | (p,q)::rest ->
      let news,erf' = equatecons (p,q) erf in
      zero_saturate erf' (union rest news);;

(* ------------------------------------------------------------------------- *)
(* Zero-saturate then check for contradictoriness.                           *)
(* ------------------------------------------------------------------------- *)

(* True when some positive equated element is equivalent to its negation.   *)
let contraeq pfn =
  let clashes x = canonize pfn x = canonize pfn (Not x) in
  exists clashes (filter positive (equated pfn));;

(* Zero-saturates, and if the result is contradictory also equates True     *)
(* with its negation to mark the contradiction.                             *)
let zero_saturate_and_check erf trigs =
  let (eqv',_ as erf') = zero_saturate erf trigs in
  if not(contraeq eqv') then erf'
  else snd(equatecons (True,Not True) erf');;

(* ------------------------------------------------------------------------- *)
(* Iterated equivalening over a set.                                         *)
(* ------------------------------------------------------------------------- *)

(* Equates each element of the list with its successor, making the whole    *)
(* list one class; lists of fewer than two elements leave eqfn unchanged.   *)
let rec equateset s0 eqfn =
  match s0 with
    a::(b::_ as rest) ->
      let (_,eqfn') = equatecons (a,b) eqfn in
      equateset rest eqfn'
  | _ -> eqfn;;

(* ------------------------------------------------------------------------- *)
(* Intersection operation on equivalence classes and relevancies.            *)
(* ------------------------------------------------------------------------- *)

(* For each element x, intersects its class under the first equivalence     *)
(* with its class under the second (via the reversed maps rev1 and rev2)    *)
(* and equates that common set in erf.  Elements already covered are        *)
(* removed from the work list.                                              *)
let rec inter els (eq1,_ as erf1) (eq2,_ as erf2) rev1 rev2 erf =
  match els with
    [] -> erf
  | x::xs ->
      let (root1,_) = tryterminus eq1 x in
      let (root2,_) = tryterminus eq2 x in
      let shared = intersect (apply rev1 root1) (apply rev2 root2) in
      let erf' = if shared = [x] then erf else equateset shared erf in
      inter (subtract xs shared) erf1 erf2 rev1 rev2 erf';;

(* ------------------------------------------------------------------------- *)
(* Reverse the equivalence mappings.                                         *)
(* ------------------------------------------------------------------------- *)

(* Maps each canonical representative to the domain elements it stands for. *)
let reverseq domain eqv =
  let add_member y f =
    let x = canonize eqv y in
    (x |-> insert y (tryapplyl f x)) f in
  itlist add_member domain undefined;;

(* ------------------------------------------------------------------------- *)
(* Special intersection taking contradictoriness into account.               *)
(* ------------------------------------------------------------------------- *)

(* True when the equivalence identifies True with its negation.             *)
let truefalse pfn =
  let canon_false = canonize pfn (Not True)
  and canon_true = canonize pfn True in
  canon_false = canon_true;;

(* If one branch is contradictory the other is returned as is; otherwise    *)
(* the equivalences common to both branches are added to erf.               *)
let stal_intersect (eq1,_ as erf1) (eq2,_ as erf2) erf =
  if truefalse eq1 then erf2
  else if truefalse eq2 then erf1 else
  let dom1 = equated eq1 and dom2 = equated eq2 in
  inter (intersect dom1 dom2) erf1 erf2
        (reverseq dom1 eq1) (reverseq dom2 eq2) erf;;

(* ------------------------------------------------------------------------- *)
(* General n-saturation for n >= 1                                           *)
(* ------------------------------------------------------------------------- *)

(* [saturate allvars n erf assigs] performs n-saturation.  It first         *)
(* zero-saturates with the given assignments; if n is 0 or a contradiction  *)
(* was found it stops.  Otherwise [splits] case-splits on each variable     *)
(* that is still its own canonical representative, saturating both cases    *)
(* at depth n-1 and keeping what they have in common, and the process       *)
(* repeats until the equivalence stops changing.                            *)
let saturate allvars =
  let rec saturate n erf assigs =
    let (eqv',_ as erf') = zero_saturate_and_check erf assigs in
    (* || replaces the deprecated [or] operator (same short-circuiting).    *)
    if n = 0 || truefalse eqv' then erf' else
    let (eqv'',_ as erf'') = splits n erf' allvars in
    if eqv'' = eqv' then erf''
    else saturate n erf'' []
  and splits n (eqv,_ as erf) vars =
    match vars with
      [] -> erf
    | p::ovars ->
          (* Skip variables already merged into another class.              *)
          if canonize eqv p <> p then splits n erf ovars else
          let erf0 = saturate (n - 1) erf [p,Not True]
          and erf1 = saturate (n - 1) erf [p,True] in
          let (eqv',_ as erf') = stal_intersect erf0 erf1 erf in
          if truefalse eqv' then erf'
          else splits n erf' ovars in
  saturate;;

(* ------------------------------------------------------------------------- *)
(* Cleaning up the triggers to represent the equivalence relation.           *)
(* ------------------------------------------------------------------------- *)

(* Element of a non-empty list whose underlying atom is smallest.           *)
let minatom = function
  | [] -> failwith "minatom: empty list"
  | fm::ofms ->
      itlist (fun x y -> if atom x < atom y then x else y) ofms fm;;

(* Maps every equated element to the [minatom] member of its class.         *)
let realcanon eqv =
  let domain = equated eqv in
  let classes = reverseq domain eqv in
  let bind x = (x |-> minatom(apply classes (canonize eqv x))) in
  itlist bind domain undefined;;

(* Rewrites both sides of an equation to their real canonical form and      *)
(* re-aligns the result.  The canonical map is built once per eqv.          *)
let substitute eqv =
  let canon = realcanon eqv in
  fun (p,q) -> align(tryapply canon p,tryapply canon q);;

(* Rewrites every trigger with subfn.  A trigger whose rewritten equation   *)
(* is trivial (both sides equal) or contradictory (one side the negation    *)
(* of the other) is dropped.  Otherwise its consequences are rewritten      *)
(* too: trivial ones are discarded, and if any of them is contradictory     *)
(* the whole consequence list is replaced by the single equation            *)
(* True = false.                                                            *)
let rec cleanup subfn trigs =
  match trigs with
    [] -> []
  | ((p,q),conseqs)::otr ->
        let (p',q') = subfn (p,q) in
        (* Both tests must use the rewritten pair; the old code compared    *)
        (* the unrewritten p with negate q'.                                *)
        if p' = q' || p' = negate q' then cleanup subfn otr else
        let conseqs' = map subfn conseqs in
        let useful,triv =
           partition (fun (p,q) -> atom p <> atom q) conseqs' in
        let news =
          if exists (fun (p,q) -> p = negate q) triv
          then (p',q'),[align(True,Not True)]
          else (p',q'),useful in
        insert news (cleanup subfn otr);;

(* ------------------------------------------------------------------------- *)
(* Saturate up to a limit.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Tries n-saturation, then (n+1)-saturation on the cleaned-up triggers,    *)
(* and so on up to depth m.  Returns true as soon as a contradiction is     *)
(* reached and false once n exceeds m.  Progress is printed at each level.  *)
let rec saturate_upto n m trigs assigs =
  let announce msg = print_string msg; print_newline() in
  if n > m then begin
    announce ("%%% Too deep for "^(string_of_int m)^"-saturation");
    false
  end else begin
    announce ("*** Starting "^(string_of_int n)^"-saturation");
    let vars,rfn = relevance trigs in
    let (eqv',_) = saturate vars n (unequal,rfn) assigs in
    if truefalse eqv' then true else
    saturate_upto (n + 1) m (cleanup (substitute eqv') trigs) []
  end;;

(* ------------------------------------------------------------------------- *)
(* Overall function.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Merges a trigger's consequences into those already stored for eq.        *)
let include_trigger (eq,conseqs) f =
  let merged = union conseqs (tryapplyl f eq) in
  (eq |-> merged) f;;

(* Tautology test: negate and simplify, convert to triplets, collect their  *)
(* triggers, and try to refute the assumption that the negation is true     *)
(* using saturation depths 0 to 2.                                          *)
let stalmarck fm =
  match psimplify(Not fm) with
    False -> true
  | True -> false
  | fm' ->
      let p,triplets = triplicate fm' in
      let add_triplet t acc = itlist include_trigger (trigger t) acc in
      let trigfn = itlist add_triplet triplets undefined in
      saturate_upto 0 2 (funset trigfn) [p,True];;

(* ------------------------------------------------------------------------- *)
(* Try the primality examples.                                               *)
(* ------------------------------------------------------------------------- *)

(*

do_list (time stalmarck)
  [prime 5;
   prime 13;
   prime 23;
   prime 43;
   prime 97];;

*)

(* ------------------------------------------------------------------------- *)
(* Artificial example of Urquhart formulas.                                  *)
(* ------------------------------------------------------------------------- *)

(* Chain of equivalences over p_1 .. p_n followed by p_1 .. p_n again.      *)
let urquhart n =
  let var k = Atom(P("p_"^(string_of_int k))) in
  let pvs = map var (1 -- n) in
  end_itlist (fun p q -> Iff(p,q)) (pvs @ pvs);;

(*
map (time stalmarck ** urquhart) [1;2;4;8;16];;

*)
(* ========================================================================= *)
(* Binary decision diagrams (BDDs) using complement edges.                   *)
(*                                                                           *)
(* In practice one would use hash tables, but we use abstract finite         *)
(* partial functions here. They might also look nicer imperatively.          *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* A node is a variable together with the indices of its two subnodes.      *)
(* A negative index denotes the complement of the node with the             *)
(* corresponding positive index.                                            *)
type bddnode = prop * int * int;;

(* ------------------------------------------------------------------------- *)
(* A BDD contains a variable order, unique and computed table.               *)
(* ------------------------------------------------------------------------- *)

(* Components: the unique table from nodes to indices, the reverse table    *)
(* from indices to nodes, the next unused index, and the variable order.    *)
type bdd = Bdd of ((bddnode,int)func * (int,bddnode)func * int) *
                  (prop->prop->bool);;

(* Prints a one-line summary showing the index counter of the BDD.          *)
let print_bdd (Bdd((_,_,n),_)) =
  let summary = "<BDD with "^(string_of_int n)^" nodes>" in
  print_string summary;;

(*
#install_printer print_bdd;;
*)

(* ------------------------------------------------------------------------- *)
(* Map a BDD node back to its components.                                    *)
(* ------------------------------------------------------------------------- *)

(* Looks up the node for index n.  Unknown indices yield the dummy node     *)
(* (P"",1,1).  For a negative index the node of -n is returned with both    *)
(* subnode indices negated.                                                 *)
let expand_node (Bdd((_,expand,_),_)) n =
  if n >= 0 then tryapplyd expand n (P"",1,1) else
  let (s,l,r) = tryapplyd expand (-n) (P"",1,1) in (s,-l,-r);;

(* ------------------------------------------------------------------------- *)
(* Lookup or insertion if not there in unique table.                         *)
(* ------------------------------------------------------------------------- *)

(* Returns the index already assigned to node, or else assigns it the next  *)
(* free index and records it in both tables.                                *)
let lookup_unique (Bdd((unique,expand,n),ord) as bdd) node =
  let found = try Some(apply unique node) with Failure _ -> None in
  match found with
    Some k -> bdd,k
  | None ->
      Bdd(((node|->n) unique,(n|->node) expand,n+1),ord),n;;

(* ------------------------------------------------------------------------- *)
(* Produce a BDD node (old or new).                                          *)
(* ------------------------------------------------------------------------- *)

(* Index for the node (s,l,r).  Equal branches collapse to the branch       *)
(* itself.  A negative left branch is stored with both branches negated     *)
(* and the resulting index negated.                                         *)
let mk_node bdd (s,l,r) =
  if l = r then bdd,l
  else if l >= 0 then lookup_unique bdd (s,l,r)
  else let bdd',n = lookup_unique bdd (s,-l,-r) in bdd',-n;;

(* ------------------------------------------------------------------------- *)
(* Create a new BDD with a given ordering.                                   *)
(* ------------------------------------------------------------------------- *)

(* Empty BDD with the given order; the index counter starts at 2.           *)
let mk_bdd ord =
  let empty_tables = (undefined,undefined,2) in
  Bdd(empty_tables,ord);;

(* ------------------------------------------------------------------------- *)
(* Extract the ordering field of a BDD.                                      *)
(* ------------------------------------------------------------------------- *)

(* Ordering used when comparing node variables: any real variable comes     *)
(* before the dummy variable P"", otherwise the stored order decides.       *)
(* && and || replace the deprecated [&] and [or] operators.                 *)
let order (Bdd(_,ord)) =
  fun s1 s2 -> (s2 = P"" && s1 <> P"") || ord s1 s2;;

(* ------------------------------------------------------------------------- *)
(* Perform an AND operation on BDDs, maintaining canonicity.                 *)
(* ------------------------------------------------------------------------- *)

(* Conjunction of the nodes m1 and m2.  The state is a pair of the BDD and  *)
(* a computed table from argument pairs to results.  The cases for -1 and   *)
(* 1 are handled directly, the table is consulted under both argument       *)
(* orders, and otherwise the function recurses on the two branches of the   *)
(* top variable and records the result in the table.                        *)
let rec bdd_and (bdd,comp as bddcomp) (m1,m2) =
  (* || replaces the deprecated [or] operator (same short-circuiting).      *)
  if m1 = -1 || m2 = -1 then bddcomp,-1
  else if m1 = 1 then bddcomp,m2 else if m2 = 1 then bddcomp,m1 else
  try bddcomp,apply comp (m1,m2) with Failure _ ->
  try  bddcomp,apply comp (m2,m1) with Failure _ ->
  let (s1,l1,r1) = expand_node bdd m1
  and (s2,l2,r2) = expand_node bdd m2 in
  (* Split on whichever variable comes first in the order; a node that      *)
  (* does not test that variable is passed unchanged to both branches.      *)
  let (s,lpair,rpair) =
      if s1 = s2 then s1,(l1,l2),(r1,r2)
      else if order bdd s1 s2 then s1,(l1,m2),(r1,m2)
      else s2,(m1,l2),(m1,r2) in
  let bddcomp1,lnew = bdd_and bddcomp lpair in
  let (bdd2,comp2),rnew = bdd_and bddcomp1 rpair in
  let bdd',n = mk_node bdd2 (s,lnew,rnew) in
  let comp' = ((m1,m2) |-> n) comp2 in (bdd',comp'),n;;

(* ------------------------------------------------------------------------- *)
(* Main formula to BDD conversion, with a store of previous subnodes.        *)
(* ------------------------------------------------------------------------- *)

(* Converts a formula to a BDD node.  Atoms listed in subbdds use their     *)
(* stored node; other atoms get a fresh node.  Negation flips the sign of   *)
(* the index, and Or, Imp and Iff are rewritten into Not and And first.     *)
let bddify subbdds =
  let neg p = Not p and conj p q = And(p,q) in
  let rec mkbdd (bdd,comp as bddcomp) fm =
    match fm with
      True -> bddcomp,1
    | False -> bddcomp,-1
    | Atom(s) ->
       (try bddcomp,assoc s subbdds with Failure _ ->
        let bdd',n = mk_node bdd (s,1,-1) in (bdd',comp),n)
    | Not(p) ->
        let bddcomp',n = mkbdd bddcomp p in
        bddcomp',-n
    | And(l,r) ->
        let after_l,nl = mkbdd bddcomp l in
        let after_r,nr = mkbdd after_l r in
        bdd_and after_r (nl,nr)
    | Or(l,r) -> mkbdd bddcomp (neg(conj (neg l) (neg r)))
    | Imp(l,r) -> mkbdd bddcomp (neg(conj l (neg r)))
    | Iff(l,r) ->
        mkbdd bddcomp (neg(conj (neg(conj l r))
                                (neg(conj (neg l) (neg r))))) in
  mkbdd;;

(* ------------------------------------------------------------------------- *)
(* Test.                                                                     *)
(* ------------------------------------------------------------------------- *)

(* A formula is accepted when its BDD is the node 1 (true); atoms are       *)
(* ordered by the built-in comparison.                                      *)

let bddtaut fm =
  let builtin_order a b = a < b in
  let start = (mk_bdd builtin_order,undefined) in
  let _,root = bddify [] start fm in
  root = 1;;

(*
bddtaut (ramsey 3 3 6);;

bddtaut (prime 17);;

bddtaut (mk_adder_test 4 2);;
*)

(* ------------------------------------------------------------------------- *)
(* Towards a more intelligent treatment of "definitions".                    *)
(* ------------------------------------------------------------------------- *)

(* Collect the conjuncts of fm into acc, descending through nested And. *)

let rec conjuncts fm acc =
  match fm with
  | And(l,r) ->
      let acc' = conjuncts r acc in
      conjuncts l acc'
  | _ -> insert fm acc;;

(* Split an implication into antecedent and consequent; a negation of p is *)
(* treated as the implication from p to False. Fails on anything else.     *)

let dest_nimp = function
  | Imp(ante,conseq) -> ante,conseq
  | Not(body) -> body,False
  | _ -> failwith "dest_nimp: not an implication or negation";;

(* Break a defining equivalence, with an atom on either side, into the pair *)
(* (defined atom, body). When both sides are atoms the left one is taken as *)
(* the defined atom. Fails on any other formula.                            *)

let dest_def fm =
  match fm with
    Iff(Atom(p),r) -> p,r
  | Iff(r,Atom(p)) -> p,r
  | _ -> failwith "not a defining equivalence";;

(* Put each definition (p,body) back in front of fm as a hypothesis of the *)
(* form (p <=> body) ==> ...                                               *)

let restore_eqs defs fm =
  let add_hyp (p,body) concl = Imp(Iff(Atom(p),body),concl) in
  itlist add_hyp defs fm;;

(* Order definitions so that each body only mentions atoms that are not     *)
(* defined by a definition still waiting in defs. Returns the sorted        *)
(* definitions together with the formula; definitions that cannot be        *)
(* sorted (for instance cyclic ones) are put back in front of the formula   *)
(* as hypotheses. Definitions already sorted into acc are always returned,  *)
(* including when the remaining ones cannot be sorted.                      *)

let rec sort_defs acc defs fm =
  if defs = [] then rev acc,fm else
  (* A definition is ready when its body uses none of the pending atoms. *)
  let ready (_,body) =
    let fvs = atoms body in
    not (exists (fun (p',_) -> mem p' fvs) defs) in
  match (try Some(find ready defs) with Failure _ -> None) with
  | Some(p,q) ->
      (* Other definitions of the same atom become ordinary hypotheses. *)
      let ps,nonps = partition (fun (p',_) -> p' = p) defs in
      let ps' = subtract ps [p,q] in
      sort_defs ((p,q)::acc) nonps (restore_eqs ps' fm)
  | None ->
      rev acc,restore_eqs defs fm;;

(* ------------------------------------------------------------------------- *)
(* Also attempt to discover a more "topological" variable ordering.          *)
(* ------------------------------------------------------------------------- *)

(* Cons x onto s unless it is already a member. *)
let sinsert x s =
  match mem x s with
  | true -> s
  | false -> x::s;;

(* Build a variable order by walking the definitions in sequence: for each *)
(* one, add the atoms of its body and then the defined atom; finally add   *)
(* the atoms of fm. The accumulated list is reversed at the end.           *)

let rec varorder pvs defs fm =
  let add_atoms body seen = itlist sinsert (atoms body) seen in
  match defs with
  | [] -> rev (add_atoms fm pvs)
  | (p,q)::rest -> varorder (sinsert p (add_atoms q pvs)) rest fm;;

(* ------------------------------------------------------------------------- *)
(* Improved setup.                                                           *)
(* ------------------------------------------------------------------------- *)

(* Convert each definition body to a BDD node in turn, recording the node  *)
(* against the defined atom, then convert fm with all of them available.   *)

let rec process bdd subbdds defs fm =
  match defs with
  | [] -> bddify subbdds bdd fm
  | (name,body)::rest ->
      let bdd',node = bddify subbdds bdd body in
      let subbdds' = (name,node)::subbdds in
      process bdd' subbdds' rest fm;;

(* Tautology test that treats defining equivalences among the hypotheses   *)
(* specially: they are sorted, converted once each, and left out of the    *)
(* variable order handed to mk_bdd.                                        *)

let ebddtaut fm =
  (* A formula that is not an implication or negation has hypothesis True. *)
  let hyps,concl = try dest_nimp fm with Failure _ -> True,fm in
  let eqs,noneqs = partition (can dest_def) (conjuncts hyps []) in
  let goal = itlist (fun h c -> Imp(h,c)) noneqs concl in
  let defs,fm' = sort_defs [] (map dest_def eqs) goal in
  let defined = map fst defs in
  let ordered = varorder [] defs fm' in
  let vars = filter (fun v -> not(mem v defined)) ordered in
  let start = (mk_bdd (earlier vars),undefined) in
  let _,root = process start [] defs fm' in
  root = 1;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
ebddtaut (prime 101);;

ebddtaut (mk_adder_test 9 5);;

ebddtaut (mk_mult_test 7);;
*)
(* ========================================================================= *)
(* Basic stuff for first order logic.                                        *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Terms.                                                                    *)
(* ------------------------------------------------------------------------- *)

(* A term is a named variable or a function symbol applied to a list of    *)
(* argument terms; an empty argument list gives a constant.                *)
type term =
  | Var of string
  | Fn of string * term list;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
Fn("sqrt",[Fn("-",[Fn("1",[]);
                   Fn("cos",[Fn("power",[Fn("+",[Var "x"; Var "y"]);
                                        Fn("2",[])])])])]);;
*)

(* ------------------------------------------------------------------------- *)
(* Abbreviation for FOL formula.                                             *)
(* ------------------------------------------------------------------------- *)

(* An atomic first order formula: a relation name and its argument terms. *)
type fol =
  | R of string * term list;;

(* ------------------------------------------------------------------------- *)
(* Trivial example of "x + y < z".                                           *)
(* ------------------------------------------------------------------------- *)

(*
Atom(R("<",[Fn("+",[Var "x"; Var "y"]); Var "z"]));;
*)

(* ------------------------------------------------------------------------- *)
(* Parsing of terms.                                                         *)
(* ------------------------------------------------------------------------- *)

(* A token counts as a constant when every character satisfies numeric, or *)
(* when it is the literal token nil.                                       *)
let is_const s = forall numeric (explode s) || s = "nil";;

(* Parse an atomic term from the token list inp, returning the term and    *)
(* the remaining tokens. Names in vs are always read as variables, even    *)
(* when is_const would otherwise make them constants.                      *)

let rec parse_atomic_term vs inp =
  match inp with
    [] -> failwith "term expected"
  | "("::rest -> parse_bracketed (parse_term vs) ")" rest
  | f::"("::")"::rest -> Fn(f,[]),rest
  | f::"("::rest ->
      papply (fun args -> Fn(f,args))
             (parse_bracketed (parse_list "," (parse_term vs)) ")" rest)
  | a::rest ->
      (if is_const a && not(mem a vs) then Fn(a,[]) else Var a),rest

(* Infix layers are nested from the outermost operator :: through +, - and *)
(* * down to ^ innermost, with atomic terms at the bottom.                 *)

and parse_term vs inp =
  parse_right_infix "::" (fun (e1,e2) -> Fn("::",[e1;e2]))
    (parse_right_infix "+" (fun (e1,e2) -> Fn("+",[e1;e2]))
       (parse_left_infix "-" (fun (e1,e2) -> Fn("-",[e1;e2]))
           (parse_right_infix "*" (fun (e1,e2) -> Fn("*",[e1;e2]))
                (parse_left_infix "^" (fun (e1,e2) -> Fn("^",[e1;e2]))
                   (parse_atomic_term vs))))) inp;;

(* Term parser built by make_parser from parse_term with an empty vs list. *)
let parset = make_parser (parse_term []);;

(* ------------------------------------------------------------------------- *)
(* Parsing of formulas.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Parse an atomic formula. An infix relation between two terms is tried   *)
(* first; if that fails, the input is read as a predicate applied to a     *)
(* bracketed argument list, or as a bare predicate name.                   *)

let parse_atom vs inp =
  let relations = ["="; "<"; "<="; ">"; ">="] in
  let infix_atom () =
    let lhs,rest = parse_term vs inp in
    if exists (nextin rest) relations then
      papply (fun rhs -> Atom(R(hd rest,[lhs;rhs])))
             (parse_term vs (tl rest))
    else failwith "" in
  try infix_atom () with Failure _ ->
  match inp with
  | p::"("::")"::rest -> Atom(R(p,[])),rest
  | p::"("::rest ->
      papply (fun args -> Atom(R(p,args)))
             (parse_bracketed (parse_list "," (parse_term vs)) ")" rest)
  | p::rest when p <> "(" -> Atom(R(p,[])),rest
  | _ -> failwith "parse_atom";;

(* Formula parser: parse_formula specialised to parse_atom with an empty   *)
(* initial list, wrapped by make_parser.                                   *)
let parse = make_parser (parse_formula parse_atom []);;

(* ------------------------------------------------------------------------- *)
(* Set up parsing of quotations.                                             *)
(* ------------------------------------------------------------------------- *)

(* Formulas go through the default parser, terms through the secondary. *)
let default_parser = parse
and secondary_parser = parset;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
<<(forall x. x < 2 ==> 2 * x <= 3) \/ false>>;;

<<|2 * x|>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Printing of terms.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Print a term in a context of precedence prec. The binary operators ^,   *)
(* *, -, + and :: are printed infix; everything else as f(args).           *)

let rec print_term prec tm =
  let infix isleft newprec sym l r =
    print_infix_term isleft prec newprec sym l r in
  match tm with
  | Var x -> print_string x
  | Fn("^",[l;r]) -> infix true 22 "^" l r
  | Fn("*",[l;r]) -> infix false 20 "*" l r
  | Fn("-",[l;r]) -> infix true 18 "-" l r
  | Fn("+",[l;r]) -> infix false 16 "+" l r
  | Fn("::",[l;r]) -> infix false 14 "::" l r
  | Fn(f,args) -> print_fargs f args

(* Print f followed by its comma-separated arguments in brackets; with no  *)
(* arguments only the name is printed.                                     *)

and print_fargs f args =
  print_string f;
  match args with
  | [] -> ()
  | first::others ->
      print_string "(";
      open_box 0;
      print_term 0 first; print_break 0 0;
      do_list (fun t -> print_string ","; print_break 0 0; print_term 0 t)
              others;
      close_box();
      print_string ")"

(* Print p sym q, bracketing when the context binds tighter than newprec.  *)
(* The operand on the non-associating side is printed one level tighter.   *)

and print_infix_term isleft oldprec newprec sym p q =
  let bracket = oldprec > newprec in
  let lprec,rprec =
    if isleft then newprec,newprec+1 else newprec+1,newprec in
  if bracket then (print_string "("; open_box 0);
  print_term lprec p;
  print_string(" "^sym); print_space();
  print_term rprec q;
  if bracket then (close_box(); print_string ")");;

(* Print a term at the lowest precedence inside its own box. *)
let printert tm =
  open_box 0;
  print_term 0 tm;
  close_box();;

(*
#install_printer printert;;
*)

(* ------------------------------------------------------------------------- *)
(* Printing of formulas.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Print an atomic formula: the comparison relations with exactly two      *)
(* arguments are shown infix, anything else as p(args). The prec argument  *)
(* is accepted but not used.                                               *)

let print_atom prec (R(p,args)) =
  match args with
  | [l;r] when mem p ["="; "<"; "<="; ">"; ">="] ->
      print_infix_term false 12 12 p l r
  | _ -> print_fargs p args;;

(* Formula printer obtained by giving formula_printer the atom printer. *)
let printer = formula_printer print_atom;;

(*
#install_printer printer;;
*)

(* ------------------------------------------------------------------------- *)
(* Examples in the main text.                                                *)
(* ------------------------------------------------------------------------- *)

(*
<<forall x y. exists z. x < z /\ y < z>>;;

<<~(forall x. P(x)) <=> exists y. ~P(y)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Model-theoretic notions, but here restricted to finite interpretations.   *)
(* ------------------------------------------------------------------------- *)

(* An interpretation packages a domain given as a list with the meanings   *)
(* of function symbols and of predicate symbols over that domain.          *)
type 'a interpretation =
  | Interp of 'a list
              * (string -> 'a list -> 'a)
              * (string -> 'a list -> bool);;

(* Projections for the three components of an interpretation. *)
let domain = function Interp(dom,_,_) -> dom
and func = function Interp(_,fns,_) -> fns
and predicate = function Interp(_,_,prs) -> prs;;

(* ------------------------------------------------------------------------- *)
(* Semantics.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Value of a term in interpretation md under the valuation v. *)
let rec termval md v tm =
  match tm with
  | Var(x) -> apply v x
  | Fn(f,args) ->
      let argvals = map (termval md v) args in
      func(md) f argvals;;

(* Truth value of a first order formula in the interpretation md under the *)
(* valuation v. Quantifiers range over the list returned by domain md.     *)

let rec holds md v fm =
  match fm with
    False -> false
  | True -> true
  | Atom(R(r,args)) -> predicate(md) r (map (termval md v) args)
  | Not(p) -> not(holds md v p)
  | And(p,q) -> (holds md v p) && (holds md v q)
  | Or(p,q) -> (holds md v p) || (holds md v q)
  | Imp(p,q) -> not(holds md v p) || (holds md v q)
  | Iff(p,q) -> (holds md v p = holds md v q)
  | Forall(x,p) ->
        (* Rebind x to each domain element in turn. *)
        forall (fun a -> holds md ((x |-> a) v) p) (domain md)
  | Exists(x,p) ->
        exists (fun a -> holds md ((x |-> a) v) p) (domain md);;

(* ------------------------------------------------------------------------- *)
(* Examples of particular interpretations.                                   *)
(* ------------------------------------------------------------------------- *)

(* Interpretation over the booleans: 0 and 1 are false and true, + is      *)
(* exclusive or, * is conjunction, and = is the only predicate.            *)

let bool_interp =
  let fns f args =
    match (f,args) with
      ("0",[]) -> false
    | ("1",[]) -> true
    | ("+",[x;y]) -> x <> y
    | ("*",[x;y]) -> x && y
    | _ -> failwith "uninterpreted function"
  and prs p args =
    match (p,args) with
      ("=",[x;y]) -> x = y
    | _ -> failwith "uninterpreted predicate" in
  Interp([false; true],fns,prs);;

(* Interpretation over the integers 0 to n - 1 with + and * taken modulo n *)
(* and = as the only predicate.                                            *)

let mod_interp n =
  let reduce k = k mod n in
  let fns f args =
    match (f,args) with
    | ("0",[]) -> 0
    | ("1",[]) -> reduce 1
    | ("+",[x;y]) -> reduce (x + y)
    | ("*",[x;y]) -> reduce (x * y)
    | _ -> failwith "uninterpreted function"
  and prs p args =
    match (p,args) with
    | ("=",[x;y]) -> x = y
    | _ -> failwith "uninterpreted predicate" in
  let dom = 0 -- (n - 1) in
  Interp(dom,fns,prs);;

(*
let fm1 = <<forall x. (x = 0) \/ (x = 1)>>;;

holds bool_interp undefined fm1;;

holds (mod_interp 2) undefined fm1;;

holds (mod_interp 3) undefined fm1;;

let fm2 = <<forall x. ~(x = 0) ==> exists y. x * y = 1>>;;

holds bool_interp undefined fm2;;

holds (mod_interp 2) undefined fm2;;

holds (mod_interp 3) undefined fm2;;

holds (mod_interp 4) undefined fm2;;

holds (mod_interp 31) undefined fm2;;

holds (mod_interp 33) undefined fm2;;
*)

(* ------------------------------------------------------------------------- *)
(* Free variables in terms and formulas.                                     *)
(* ------------------------------------------------------------------------- *)

(* Variables occurring in a term, combined with union. *)
let rec fvt tm =
  match tm with
  | Var x -> [x]
  | Fn(_,args) -> itlist (fun arg acc -> union (fvt arg) acc) args [];;

(* Free variables of a formula: a quantifier removes its bound variable    *)
(* from the free variables of its body.                                    *)

let rec fv fm =
  match fm with
  | False | True -> []
  | Atom(R(_,args)) -> itlist (fun arg acc -> union (fvt arg) acc) args []
  | Not(p) -> fv p
  | And(p,q) | Or(p,q) | Imp(p,q) | Iff(p,q) -> union (fv p) (fv q)
  | Forall(x,p) | Exists(x,p) -> subtract (fv p) [x];;

(* ------------------------------------------------------------------------- *)
(* Substitution within terms.                                                *)
(* ------------------------------------------------------------------------- *)

(* Build the mapping that sends each variable in vlist to the term at the  *)
(* same position in tlist, starting from the undefined mapping.            *)

let instantiate vlist tlist =
  let bind x t env = (x |-> t) env in
  itlist2 bind vlist tlist undefined;;

(* Apply the substitution sfn throughout a term; variables on which sfn is *)
(* not defined are left as they are.                                       *)

let rec termsubst sfn tm =
  match tm with
  | Var x -> tryapplyd sfn x tm
  | Fn(f,args) ->
      let args' = map (termsubst sfn) args in
      Fn(f,args');;

(* ------------------------------------------------------------------------- *)
(* Incorrect substitution in formulas, and example showing why it's wrong.   *)
(* ------------------------------------------------------------------------- *)

let rec formsubst subfn fm =    (* WRONG! *)
  (* Deliberately naive: under a quantifier the bound variable is only     *)
  (* removed from the substitution, with no renaming to avoid capture.     *)
  let sub = formsubst subfn in
  match fm with
  | False -> False
  | True -> True
  | Atom(R(p,args)) -> Atom(R(p,map (termsubst subfn) args))
  | Not(p) -> Not(sub p)
  | And(p,q) -> And(sub p,sub q)
  | Or(p,q) -> Or(sub p,sub q)
  | Imp(p,q) -> Imp(sub p,sub q)
  | Iff(p,q) -> Iff(sub p,sub q)
  | Forall(x,p) -> Forall(x,formsubst (undefine x subfn) p)
  | Exists(x,p) -> Exists(x,formsubst (undefine x subfn) p);;

(* Run the naive version above: the x substituted for y ends up under the  *)
(* binder for x, so the result should be the captured forall x. x = x.     *)
formsubst ("y" := Var "x") <<forall x. x = y>>;;

(* ------------------------------------------------------------------------- *)
(* Variant function and examples.                                            *)
(* ------------------------------------------------------------------------- *)

(* Append primes to x until the name is not in vars. *)
let rec variant x vars =
  if not (mem x vars) then x
  else variant (x^"'") vars;;

(*
variant "x" ["y"; "z"];;

variant "x" ["x"; "y"];;

variant "x" ["x"; "x'"];;
*)

(* ------------------------------------------------------------------------- *)
(* The correct version.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Capture-avoiding substitution in formulas. Quantified cases are passed  *)
(* to formsubstq, which renames the bound variable when needed.            *)

let rec formsubst subfn fm =
  let sub = formsubst subfn in
  match fm with
  | False -> False
  | True -> True
  | Atom(R(p,args)) -> Atom(R(p,map (termsubst subfn) args))
  | Not(p) -> Not(sub p)
  | And(p,q) -> And(sub p,sub q)
  | Or(p,q) -> Or(sub p,sub q)
  | Imp(p,q) -> Imp(sub p,sub q)
  | Iff(p,q) -> Iff(sub p,sub q)
  | Forall(x,p) -> formsubstq subfn (fun (x,p) -> Forall(x,p)) (x,p)
  | Exists(x,p) -> formsubstq subfn (fun (x,p) -> Exists(x,p)) (x,p)

(* Substitute under a binder of x. If the image of some other free         *)
(* variable of the body mentions x, the bound variable is renamed to a     *)
(* variant not free in the substituted body; otherwise x is kept.          *)

and formsubstq subfn quant (x,p) =
  let subfn' = undefine x subfn in
  let captures y = mem x (fvt(tryapplyd subfn' y (Var y))) in
  let x' =
    if exists captures (subtract (fv p) [x])
    then variant x (fv(formsubst subfn' p))
    else x in
  quant(x',formsubst ((x |-> Var x') subfn) p);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
formsubst ("y" := Var "x") <<forall x. x = y>>;;

formsubst ("y" := Var "x") <<forall x x'. x = y ==> x = x'>>;;
*)
(* ========================================================================= *)
(* Prenex and Skolem normal forms.                                           *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Routine simplification. Like "psimplify" but with quantifier clauses.     *)
(* ------------------------------------------------------------------------- *)

(* One simplification step: drop a quantifier whose variable is not free   *)
(* in its body, and leave everything else to psimplify1.                   *)

let simplify1 fm =
  match fm with
  | Forall(x,body) | Exists(x,body) ->
      if mem x (fv body) then fm else body
  | _ -> psimplify1 fm;;

(* Simplify bottom-up: subformulas first, then one step at the top. *)
let rec simplify fm =
  let binary mk p q = simplify1 (mk (simplify p) (simplify q)) in
  match fm with
  | Not p -> simplify1 (Not(simplify p))
  | And(p,q) -> binary (fun l r -> And(l,r)) p q
  | Or(p,q) -> binary (fun l r -> Or(l,r)) p q
  | Imp(p,q) -> binary (fun l r -> Imp(l,r)) p q
  | Iff(p,q) -> binary (fun l r -> Iff(l,r)) p q
  | Forall(x,p) -> simplify1 (Forall(x,simplify p))
  | Exists(x,p) -> simplify1 (Exists(x,simplify p))
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
simplify <<(forall x y. P(x) \/ (P(y) /\ false)) ==> exists z. P(z)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Negation normal form.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Negation normal form: implications and equivalences are eliminated and  *)
(* negations are pushed inwards through connectives and quantifiers.       *)

let rec nnf fm =
  let neg p = nnf (Not p) in
  match fm with
  (* Unnegated cases. *)
  | And(p,q) -> And(nnf p,nnf q)
  | Or(p,q) -> Or(nnf p,nnf q)
  | Imp(p,q) -> Or(neg p,nnf q)
  | Iff(p,q) -> Or(And(nnf p,nnf q),And(neg p,neg q))
  | Forall(x,p) -> Forall(x,nnf p)
  | Exists(x,p) -> Exists(x,nnf p)
  (* Negated cases. *)
  | Not(Not p) -> nnf p
  | Not(And(p,q)) -> Or(neg p,neg q)
  | Not(Or(p,q)) -> And(neg p,neg q)
  | Not(Imp(p,q)) -> And(nnf p,neg q)
  | Not(Iff(p,q)) -> Or(And(nnf p,neg q),And(neg p,nnf q))
  | Not(Forall(x,p)) -> Exists(x,neg p)
  | Not(Exists(x,p)) -> Forall(x,neg p)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Example of NNF function in action.                                        *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<(forall x. P(x))
           ==> ((exists y. Q(y)) <=> exists z. P(z) /\ Q(z))>> in
nnf fm;;

let andrews =
 <<((exists x. forall y. P(x) <=> P(y)) <=>
    ((exists x. Q(x)) <=> (forall y. Q(y)))) <=>
   ((exists x. forall y. Q(x) <=> Q(y)) <=>
    ((exists x. P(x)) <=> (forall y. P(y))))>> in
 nnf andrews;;
*)

(* ------------------------------------------------------------------------- *)
(* Prenex normal form.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Curried versions of the quantifier and connective constructors. *)
let mk_all x p = Forall(x,p);;
let mk_ex x p = Exists(x,p);;
let mk_and p q = And(p,q);;
let mk_or p q = Or(p,q);;

(* Pull quantifiers out from the immediate arguments of a conjunction or   *)
(* disjunction. The two cases listed first merge a pair of like            *)
(* quantifiers into one; the others pull a single quantifier from one side.*)

let rec pullquants fm =
  match fm with
  | And(Forall(x,p),Forall(y,q)) -> pullquant_2 fm mk_all mk_and x y p q
  | Or(Exists(x,p),Exists(y,q)) -> pullquant_2 fm mk_ex mk_or x y p q
  | And(Forall(x,p),q) -> pullquant_l fm mk_all mk_and x p q
  | And(p,Forall(x,q)) -> pullquant_r fm mk_all mk_and x p q
  | Or(Forall(x,p),q) -> pullquant_l fm mk_all mk_or x p q
  | Or(p,Forall(x,q)) -> pullquant_r fm mk_all mk_or x p q
  | And(Exists(x,p),q) -> pullquant_l fm mk_ex mk_and x p q
  | And(p,Exists(x,q)) -> pullquant_r fm mk_ex mk_and x p q
  | Or(Exists(x,p),q) -> pullquant_l fm mk_ex mk_or x p q
  | Or(p,Exists(x,q)) -> pullquant_r fm mk_ex mk_or x p q
  | _ -> fm

(* Quantifier on the left argument: rename x away from the free variables  *)
(* of the whole formula, then continue inside the new body.                *)

and pullquant_l fm quant op x p q =
  let x' = variant x (fv fm) in
  let p' = formsubst (x := Var x') p in
  quant x' (pullquants (op p' q))

(* Quantifier on the right argument. *)

and pullquant_r fm quant op x p q =
  let x' = variant x (fv fm) in
  let q' = formsubst (x := Var x') q in
  quant x' (pullquants (op p q'))

(* Quantifiers on both arguments, merged under one fresh variable. *)

and pullquant_2 fm quant op x y p q =
  let x' = variant x (fv fm) in
  let p' = formsubst (x := Var x') p in
  let q' = formsubst (y := Var x') q in
  quant x' (pullquants (op p' q'));;

(* Prenex the subformulas first, then pull their quantifiers out through   *)
(* each conjunction and disjunction.                                       *)

let rec prenex fm =
  match fm with
  | Forall(x,body) -> Forall(x,prenex body)
  | Exists(x,body) -> Exists(x,prenex body)
  | And(l,r) -> pullquants (mk_and (prenex l) (prenex r))
  | Or(l,r) -> pullquants (mk_or (prenex l) (prenex r))
  | _ -> fm;;

(* Prenex normal form: simplify, convert to NNF, then pull quantifiers. *)
let pnf fm =
  let in_nnf = nnf (simplify fm) in
  prenex in_nnf;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
pnf <<forall x. P(x) ==> exists y z. Q(y) \/ ~(exists z. P(z) /\ Q(z))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Get the functions in a term and formula.                                  *)
(* ------------------------------------------------------------------------- *)

(* Function symbols of a term, each paired with the number of arguments it *)
(* is applied to.                                                          *)

let rec funcs tm =
  match tm with
  | Var _ -> []
  | Fn(f,args) ->
      let here = [f,length args] in
      itlist (fun arg acc -> union (funcs arg) acc) args here;;

(* Function symbols, with arities, gathered from every atom of a formula. *)
let functions fm =
  let atom_funcs (R(_,args)) =
    itlist (fun arg acc -> union (funcs arg) acc) args [] in
  atom_union atom_funcs fm;;

(* ------------------------------------------------------------------------- *)
(* Core Skolemization function.                                              *)
(* ------------------------------------------------------------------------- *)

(* Replace existential quantifiers by Skolem terms. corr lists pairs of a  *)
(* function term and the formula it was introduced for; its function names *)
(* are the ones avoided when a new Skolem name is chosen.                  *)

let rec skolem fm corr =
  match fm with
  | Exists(y,p) ->
      let xs = fv(fm) in
      let fns = map (fun (Fn(f,_),_) -> f) corr in
      (* Prefix c_ when there are no free variables, f_ otherwise. *)
      let base = if xs = [] then "c_"^y else "f_"^y in
      let f = variant base fns in
      let fx = Fn(f,map (fun x -> Var x) xs) in
      skolem (formsubst (y := fx) p) ((fx,fm)::corr)
  | Forall(x,p) ->
      let p',corr' = skolem p corr in
      Forall(x,p'),corr'
  | And(p,q) -> skolem2 (fun (p,q) -> And(p,q)) (p,q) corr
  | Or(p,q) -> skolem2 (fun (p,q) -> Or(p,q)) (p,q) corr
  | _ -> fm,corr

(* Skolemize both arguments of a binary connective, left one first, and    *)
(* thread corr through both.                                               *)

and skolem2 cons (p,q) corr =
  let p',corr1 = skolem p corr in
  let q',corr2 = skolem q corr1 in
  cons(p',q'),corr2;;

(* ------------------------------------------------------------------------- *)
(* Overall Skolemization function.                                           *)
(* ------------------------------------------------------------------------- *)

(* Simplify, convert to NNF and Skolemize. The functions already present   *)
(* seed the list of names to avoid; only their names are kept.             *)

let askolemize fm =
  let fm1 = nnf(simplify fm) in
  let seed (name,_) = Fn(name,[]),False in
  let skolemized,_ = skolem fm1 (map seed (functions fm1)) in
  skolemized;;

(* Strip all leading universal quantifiers. *)
let rec specialize = function
  | Forall(_,body) -> specialize body
  | fm -> fm;;

(* Skolemize, prenex, and drop the leading universal quantifiers. *)
let skolemize fm =
  let skolemized = askolemize fm in
  specialize (pnf skolemized);;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
skolemize <<exists y. x < y ==> forall u. exists v. x * u < y * v>>;;

skolemize
 <<forall x. P(x)
             ==> (exists y z. Q(y) \/ ~(exists z. P(z) /\ Q(z)))>>;;
*)
(* ========================================================================= *)
(* Relation between FOL and propositional logic; Herbrand theorem.           *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* We want to generalize a formula before negating and refuting.             *)
(* ------------------------------------------------------------------------- *)

(* Universally quantify every free variable of fm. *)
let generalize fm =
  let close x body = Forall(x,body) in
  itlist close (fv fm) fm;;

(* ------------------------------------------------------------------------- *)
(* Propositional valuation.                                                  *)
(* ------------------------------------------------------------------------- *)

(* Evaluate fm propositionally, reading the value of each atom p from d    *)
(* applied to the atomic formula Atom p.                                   *)
let pholds d fm =
  let valuation p = d(Atom p) in
  eval fm valuation;;

(* ------------------------------------------------------------------------- *)
(* Characteristic function of Herbrand interpretation for p.                 *)
(* ------------------------------------------------------------------------- *)

(* Test whether tm is a ground term built only from the symbols in funcs,  *)
(* each used with the arity listed there. Any variable makes it false.     *)

let rec herbdom funcs tm =
  match tm with
    Var _ -> false
  | Fn(f,a) -> mem (f,length a) funcs && forall (herbdom funcs) a;;

(* ------------------------------------------------------------------------- *)
(* Get the constants for Herbrand base, adding nullary one if necessary.     *)
(* ------------------------------------------------------------------------- *)

(* Split the function symbols of fm into constants (arity 0) and the rest; *)
(* if there is no constant, the nullary symbol c is supplied.              *)

let herbfuns fm =
  let is_constant (_,arity) = arity = 0 in
  match partition is_constant (functions fm) with
  | [],fns -> ["c",0],fns
  | split -> split;;

(* ------------------------------------------------------------------------- *)
(* Enumeration of ground terms and m-tuples, ordered by total fns.           *)
(* ------------------------------------------------------------------------- *)

(* Ground terms containing exactly n function applications from funcs,     *)
(* with the terms in cntms as the n = 0 case.                              *)

let rec groundterms cntms funcs n =
  if n = 0 then cntms else
  let headed_by (f,arity) =
    map (fun args -> Fn(f,args)) (groundtuples cntms funcs (n - 1) arity) in
  itlist (fun fsym rest -> headed_by fsym @ rest) funcs []

(* All m-tuples of ground terms whose function counts add up to n; k is    *)
(* the share of the count given to the first component.                    *)

and groundtuples cntms funcs n m =
  if m = 0 then (if n = 0 then [[]] else []) else
  let with_head_count k =
    allpairs (fun h t -> h::t)
             (groundterms cntms funcs k)
             (groundtuples cntms funcs (n - k) (m - 1)) in
  itlist (fun k rest -> with_head_count k @ rest) (0 -- n) [];;

(* ------------------------------------------------------------------------- *)
(* Iterate modifier "mfn" over ground terms till "tfn" fails.                *)
(* ------------------------------------------------------------------------- *)

(* Generic Herbrand loop. Each ground instance of fl0 is folded into fl by *)
(* mfn; the loop stops, returning the tuples tried, as soon as tfn rejects *)
(* the result. When the current batch of tuples runs out, the batch for    *)
(* the next value of n is generated. A progress line is printed per step.  *)

let rec herbloop mfn tfn fl0 cntms funcs fvs n fl tried tuples =
  let progress =
    string_of_int(length tried)^" ground instances tried; "^
    string_of_int(length fl)^" items in list" in
  print_string progress;
  print_newline();
  match tuples with
  | [] ->
      let newtups = groundtuples cntms funcs n (length fvs) in
      herbloop mfn tfn fl0 cntms funcs fvs (n + 1) fl tried newtups
  | tup::tups ->
      let inst = formsubst(instantiate fvs tup) in
      let fl' = mfn fl0 inst fl in
      let tried' = tup::tried in
      if tfn fl' then herbloop mfn tfn fl0 cntms funcs fvs n fl' tried' tups
      else tried';;

(* ------------------------------------------------------------------------- *)
(* Hence a simple Gilmore-type procedure.                                    *)
(* ------------------------------------------------------------------------- *)

(* Herbrand loop over DNF: each instance is distributed over the current   *)
(* disjuncts, contradictory ones are filtered out, and the loop continues  *)
(* while any disjunct remains.                                             *)

let gilmore_loop =
  let extend djs0 ifn djs =
    let instance = smap (smap ifn) djs0 in
    filter (non contradictory) (distrib instance djs)
  and nonempty djs = djs <> [] in
  herbloop extend nonempty;;

(* Skolemize the negated universal closure of fm and run the Gilmore loop  *)
(* on it; the result is the number of ground instances tried.              *)

let gilmore fm =
  let sfm = skolemize(Not(generalize fm)) in
  let fvs = fv sfm in
  let consts,funcs = herbfuns sfm in
  let cntms = smap (fun (c,_) -> Fn(c,[])) consts in
  let tried = gilmore_loop (simpdnf sfm) cntms funcs fvs 0 [[]] [] [] in
  length tried;;

(* ------------------------------------------------------------------------- *)
(* First example and a little tracing.                                       *)
(* ------------------------------------------------------------------------- *)

(*
gilmore <<exists x. forall y. P(x) ==> P(y)>>;;

let sfm = skolemize(Not <<exists x. forall y. P(x) ==> P(y)>>);;

(* ------------------------------------------------------------------------- *)
(* Quick examples.                                                           *)
(* ------------------------------------------------------------------------- *)

let p19 = gilmore
 <<exists x. forall y z. (P(y) ==> Q(z)) ==> P(x) ==> Q(x)>>;;

let p24 = gilmore
 <<~(exists x. U(x) /\ Q(x)) /\
   (forall x. P(x) ==> Q(x) \/ R(x)) /\
   ~(exists x. P(x) ==> (exists x. Q(x))) /\
   (forall x. Q(x) /\ R(x) ==> U(x))
   ==> (exists x. P(x) /\ R(x))>>;;

let p39 = gilmore
 <<~(exists x. forall y. P(y,x) <=> ~P(y,y))>>;;

let p42 = gilmore
 <<~(exists y. forall x. P(x,y) <=> ~(exists z. P(x,z) /\ P(z,x)))>>;;

let p44 = gilmore
 <<(forall x. P(x) ==> (exists y. G(y) /\ H(x,y)) /\
   (exists y. G(y) /\ ~H(x,y))) /\
   (exists x. J(x) /\ (forall y. G(y) ==> H(x,y)))
   ==> (exists x. J(x) /\ ~P(x))>>;;

let p59 = gilmore
 <<(forall x. P(x) <=> ~P(f(x))) ==> (exists x. P(x) /\ ~P(f(x)))>>;;

(* ------------------------------------------------------------------------- *)
(* Slightly less easy examples.                                              *)
(* ------------------------------------------------------------------------- *)

let p45 = gilmore
 <<(forall x.
     P(x) /\ (forall y. G(y) /\ H(x,y) ==> J(x,y))
     ==> (forall y. G(y) /\ H(x,y) ==> R(y))) /\
   ~(exists y. L(y) /\ R(y)) /\
   (exists x. P(x) /\ (forall y. H(x,y) ==> L(y)) /\
                      (forall y. G(y) /\ H(x,y) ==> J(x,y)))
   ==> (exists x. P(x) /\ ~(exists y. G(y) /\ H(x,y)))>>;;

let p60 = gilmore
 <<forall x. P(x,f(x)) <=>
             exists y. (forall z. P(z,y) ==> P(z,f(x))) /\ P(x,y)>>;;

let p43 = gilmore
 <<(forall x y. Q(x,y) <=> forall z. P(z,x) <=> P(z,y))
   ==> forall x y. Q(x,y) <=> Q(y,x)>>;;

*)

(* ------------------------------------------------------------------------- *)
(* The Davis-Putnam procedure for first order logic.                         *)
(* ------------------------------------------------------------------------- *)

(* Clauses for fm: negate every literal in the DNF of the negation of fm. *)
let clausal fm =
  let dnf_of_negation = simpdnf(nnf(Not fm)) in
  smap (smap negate) dnf_of_negation;;

(* Add the instance of the clauses cjs0 under ifn to the clauses cjs. *)
let dp_mfn cjs0 ifn cjs =
  let instance = smap (smap ifn) cjs0 in
  union instance cjs;;

(* Herbrand loop that accumulates clause instances with dp_mfn and goes on *)
(* for as long as dpll returns true on the accumulated clauses.            *)
let dp_loop = herbloop dp_mfn dpll;;

(* Davis-Putnam procedure: returns the number of ground instances tried.   *)
(* A Skolemized negation equal to False needs no instances; one equal to   *)
(* True makes the procedure fail.                                          *)

let davisputnam fm =
  match skolemize(Not(generalize fm)) with
  | False -> 0
  | True -> failwith "davisputnam"
  | sfm ->
      let fvs = fv sfm in
      let consts,funcs = herbfuns sfm in
      let cntms = smap (fun (c,_) -> Fn(c,[])) consts in
      let tried = dp_loop (clausal sfm) cntms funcs fvs 0 [] [] [] in
      length tried;;

(* ------------------------------------------------------------------------- *)
(* Show how much better than the Gilmore procedure this can be.              *)
(* ------------------------------------------------------------------------- *)

(*
let p20 = davisputnam
 <<(forall x y. exists z. forall w. P(x) /\ Q(y) ==> R(z) /\ U(w))
   ==> (exists x y. P(x) /\ Q(y)) ==> (exists z. R(z))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Show the sensitivity to order: try also with variant suggested.           *)
(* ------------------------------------------------------------------------- *)

(* Same loop as herbloop, except that each new batch of ground tuples is   *)
(* reversed before being tried.                                            *)

let rec herbloop' mfn tfn fl0 cntms funcs fvs n fl tried tuples =
  let progress =
    string_of_int(length tried)^" ground instances tried; "^
    string_of_int(length fl)^" items in list" in
  print_string progress;
  print_newline();
  match tuples with
  | [] ->
      let newtups = rev(groundtuples cntms funcs n (length fvs)) in
      herbloop' mfn tfn fl0 cntms funcs fvs (n + 1) fl tried newtups
  | tup::tups ->
      let inst = formsubst(instantiate fvs tup) in
      let fl' = mfn fl0 inst fl in
      let tried' = tup::tried in
      if tfn fl' then herbloop' mfn tfn fl0 cntms funcs fvs n fl' tried' tups
      else tried';;

(* Davis-Putnam loop on top of the reversed-order Herbrand loop. *)
let dp_loop' =
  let add_instance cjs0 ifn cjs =
    let instance = smap (smap ifn) cjs0 in
    union instance cjs in
  herbloop' add_instance dpll;;

(* Davis-Putnam procedure using the reversed-order loop; returns the       *)
(* number of ground instances tried.                                       *)

let davisputnam' fm =
  match skolemize(Not(generalize fm)) with
  | False -> 0
  | True -> failwith "davisputnam"
  | sfm ->
      let fvs = fv sfm in
      let consts,funcs = herbfuns sfm in
      let cntms = smap (fun (c,_) -> Fn(c,[])) consts in
      let tried = dp_loop' (clausal sfm) cntms funcs fvs 0 [] [] [] in
      length tried;;

(*
let p36 = davisputnam
 <<(forall x. exists y. P(x,y)) /\
   (forall x. exists y. G(x,y)) /\
   (forall x y. P(x,y) \/ G(x,y)
                ==> (forall z. P(y,z) \/ G(y,z) ==> H(x,z)))
   ==> (forall x. exists y. H(x,y))>>;;

let p36 = davisputnam'
 <<(forall x. exists y. P(x,y)) /\
   (forall x. exists y. G(x,y)) /\
   (forall x y. P(x,y) \/ G(x,y)
                ==> (forall z. P(y,z) \/ G(y,z) ==> H(x,z)))
   ==> (forall x. exists y. H(x,y))>>;;

let p29 = davisputnam
 <<(exists x. P(x)) /\ (exists x. G(x)) ==>
   ((forall x. P(x) ==> H(x)) /\ (forall x. G(x) ==> J(x)) <=>
    (forall x y. P(x) /\ G(y) ==> H(x) /\ J(y)))>>;;

let p29 = davisputnam'
 <<(exists x. P(x)) /\ (exists x. G(x)) ==>
   ((forall x. P(x) ==> H(x)) /\ (forall x. G(x) ==> J(x)) <=>
    (forall x y. P(x) /\ G(y) ==> H(x) /\ J(y)))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Try to cut out useless instantiations in final result.                    *)
(* ------------------------------------------------------------------------- *)

(* Prune the instances in dunno.  Each candidate is dropped in turn: if the  *)
(* clause set built from the instances already kept plus those still to be   *)
(* examined passes dpll without it, the candidate is needed and is kept;     *)
(* otherwise it is discarded.  The kept instances are returned on top of     *)
(* the initial contents of need.                                             *)
let dp_refine cjs0 fvs dunno need =
  let add_instance tup acc =
    dp_mfn cjs0 (formsubst (instantiate fvs tup)) acc in
  let rec refine pending kept =
    match pending with
      [] -> kept
    | cl::rest ->
        let others = kept @ rest in
        if dpll(itlist add_instance others []) then refine rest (cl::kept)
        else refine rest kept in
  refine dunno need;;

(* Run dp_loop to obtain a list of tried tuples, then pass that list         *)
(* through dp_refine (starting from an empty needed-list) to prune it.       *)
let dp_refine_loop cjs0 cntms funcs fvs n cjs tried tuples =
  dp_refine cjs0 fvs (dp_loop cjs0 cntms funcs fvs n cjs tried tuples) [];;

(* ------------------------------------------------------------------------- *)
(* Show how few of the instances we really need. Hence unification!          *)
(* ------------------------------------------------------------------------- *)

(* Same as davisputnam, except that the instances found are pruned by        *)
(* dp_refine_loop before being counted.                                      *)
let davisputnam'' fm =
  match skolemize(Not(generalize fm)) with
    False -> 0
  | True -> failwith "davisputnam"
  | sfm ->
      let fvs = fv sfm and consts,funcs = herbfuns sfm in
      let cntms = smap (fun (c,_) -> Fn(c,[])) consts in
      let needed = dp_refine_loop (clausal sfm) cntms funcs fvs 0 [] [] [] in
      length needed;;

(*
let p36 = davisputnam''
 <<(forall x. exists y. P(x,y)) /\
   (forall x. exists y. G(x,y)) /\
   (forall x y. P(x,y) \/ G(x,y)
                ==> (forall z. P(y,z) \/ G(y,z) ==> H(x,z)))
   ==> (forall x. exists y. H(x,y))>>;;

let p29 = davisputnam''
 <<(exists x. P(x)) /\ (exists x. G(x)) ==>
   ((forall x. P(x) ==> H(x)) /\ (forall x. G(x) ==> J(x)) <=>
    (forall x y. P(x) /\ G(y) ==> H(x) /\ J(y)))>>;;
*)
(* ========================================================================= *)
(* Unification for first order terms.                                        *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Decide whether binding x to t under env would be trivial.  For a variable *)
(* t the answer is true when t is x itself or leads to x by following the    *)
(* bindings in env.  For a compound t the answer is false unless x is        *)
(* reached through one of the arguments, in which case the binding would be  *)
(* cyclic and Failure is raised.                                             *)
let rec istriv env x t =
  match t with
    Fn(_,args) ->
      if exists (istriv env x) args then failwith "cyclic" else false
  | Var y ->
      if y = x then true
      else if defined env y then istriv env x (apply env y)
      else false;;

(* ------------------------------------------------------------------------- *)
(* Main unification procedure                                                *)
(* ------------------------------------------------------------------------- *)

(* Extend env so that it unifies every pair of terms in eqs.  Two compound   *)
(* terms must agree on function symbol and argument count, and are replaced  *)
(* by their pairwise argument equations.  A variable on the left is either   *)
(* expanded through its existing binding or bound to the other side (unless  *)
(* istriv says the binding is trivial).  A variable on the right is handled  *)
(* by swapping the pair.  Raises Failure when unification is impossible.     *)
let rec unify env eqs =
  match eqs with
    [] -> env
  | (lhs,rhs)::rest ->
      match lhs,rhs with
        Fn(f,fargs),Fn(g,gargs) ->
          if f <> g || length fargs <> length gargs
          then failwith "impossible unification"
          else unify env (zip fargs gargs @ rest)
      | Var x,_ ->
          if defined env x then unify env ((apply env x,rhs)::rest) else
          let env' = if istriv env x rhs then env else (x|->rhs) env in
          unify env' rest
      | Fn _,Var _ -> unify env ((rhs,lhs)::rest);;

(* ------------------------------------------------------------------------- *)
(* Unification reaching a final solved form (often this isn't needed).       *)
(* ------------------------------------------------------------------------- *)

(* Iterate an environment to a fixpoint.  Each round fails if any variable   *)
(* occurs in the free variables of its own image; otherwise every image is   *)
(* rewritten with termsubst under the current environment.  Iteration stops  *)
(* when the rebuilt bindings equal the previous ones, and the environment    *)
(* from before that last rebuild is returned.                                *)
let solve env =
  let binds_itself (x,t) = mem x (fvt t) in
  let rec iterate env fs =
    if exists binds_itself fs then failwith "solve: cyclic" else
    let env' =
      itlist (fun (x,t) -> x |-> termsubst env t) fs undefined in
    let fs' = funset env' in
    if fs' = fs then env else iterate env' fs' in
  iterate env (funset env);;

(* Unify eqs starting from the undefined environment, then run solve on      *)
(* the resulting bindings.                                                   *)
let fullunify eqs =
  let env = unify undefined eqs in
  solve env;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Compute the full unifier of eqs and apply it to both sides of every       *)
(* equation, returning the list of instantiated pairs.                       *)
let unify_and_apply eqs =
  let sub = termsubst (fullunify eqs) in
  map (fun (lhs,rhs) -> sub lhs,sub rhs) eqs;;

(*
unify_and_apply [<<|f(x,g(y))|>>,<<|f(f(z),w)|>>];;

unify_and_apply [<<|f(x,y)|>>,<<|f(y,x)|>>];;

unify_and_apply [<<|x_0|>>,<<|f(x_1,x_1)|>>;
                 <<|x_1|>>,<<|f(x_2,x_2)|>>;
                 <<|x_2|>>,<<|f(x_3,x_3)|>>];;
*)
(* ========================================================================= *)
(* Tableaux, seen as an optimized version of a Prawitz-like procedure.       *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Unify literals (just pretend the toplevel relation is a function).        *)
(* ------------------------------------------------------------------------- *)

(* Unify a pair of literals under env.  Two atoms are unified by treating    *)
(* each predicate application as a term; two negations are unified through   *)
(* their bodies; False unifies with False.  Any other pair raises Failure.   *)
let rec unify_literals env lits =
  match lits with
    False,False -> env
  | Not p,Not q -> unify_literals env (p,q)
  | Atom(R(p1,a1)),Atom(R(p2,a2)) -> unify env [Fn(p1,a1),Fn(p2,a2)]
  | _ -> failwith "Can't unify literals";;

(* ------------------------------------------------------------------------- *)
(* Unify complementary literals.                                             *)
(* ------------------------------------------------------------------------- *)

(* Unify p with the negation of q, i.e. make the two literals complementary. *)
let unify_complements env (p,q) =
  let q' = negate q in
  unify_literals env (p,q');;

(* ------------------------------------------------------------------------- *)
(* Unify and refute a set of disjuncts.                                      *)
(* ------------------------------------------------------------------------- *)

(* Find one environment that refutes every member of djs.  Each member is    *)
(* split into positive and negative literals; tryfind searches the           *)
(* positive/negative pairs for one whose complementary unification still     *)
(* lets the remaining members be refuted.                                    *)
let rec unify_refute djs env =
  match djs with
    [] -> env
  | cjs::rest ->
      let pos,neg = partition positive cjs in
      let candidates = allpairs (fun p q -> (p,q)) pos neg in
      let close_with pq = unify_refute rest (unify_complements env pq) in
      tryfind close_with candidates;;

(* ------------------------------------------------------------------------- *)
(* Hence a Prawitz-like procedure (using unification on DNF).                *)
(* ------------------------------------------------------------------------- *)

(* One round of the Prawitz-style search.  The free variables fvs of djs0    *)
(* are replaced by fresh variables whose names are built from the integers   *)
(* n + k, for k drawn from 1 -- length fvs.  The renamed copy of djs0 is     *)
(* combined with the accumulated DNF djs by distrib, and unify_refute is     *)
(* tried on the result.                                                      *)
(*                                                                           *)
(* Returns the pair (refuting environment, number of instances used).  On    *)
(* Failure the search recurses with the enlarged DNF and with n advanced by  *)
(* the number of free variables.                                             *)
(*                                                                           *)
(* When fvs is empty the instance count cannot be computed as                *)
(* n / length fvs, since that raises Division_by_zero; it is reported as 1.  *)
(* In that case another round would only rebuild the same instance with the  *)
(* same n, so a failed refutation raises Failure instead of looping.         *)
let rec prawitz_loop djs0 fvs djs n =
  let nvars = length fvs in
  let newvars =
    map (fun k -> "_" ^ string_of_int (n + k)) (1 -- nvars) in
  let inst = instantiate fvs (map (fun x -> Var x) newvars) in
  let djs1 = distrib (smap (smap (formsubst inst)) djs0) djs in
  let instances = if nvars = 0 then 1 else n / nvars + 1 in
  try unify_refute djs1 undefined,instances
  with Failure _ ->
    if nvars = 0
    then failwith "prawitz_loop: ground formula is not refutable"
    else prawitz_loop djs0 fvs djs1 (n + nvars);;

(* Prawitz-like procedure.  Works on the Skolemized negation of the          *)
(* generalized goal: returns 0 when that is False, raises Failure when it    *)
(* is True, and otherwise returns the second component of the pair produced  *)
(* by prawitz_loop on its DNF.                                               *)
let prawitz fm =
  match skolemize(Not(generalize fm)) with
    False -> 0
  | True -> failwith "prawitz"
  | fm0 -> snd(prawitz_loop (simpdnf fm0) (fv fm0) [[]] 0);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
let p20 = prawitz
 <<(forall x y. exists z. forall w. P(x) /\ Q(y) ==> R(z) /\ U(w))
   ==> (exists x y. P(x) /\ Q(y)) ==> (exists z. R(z))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Comparison of number of ground instances.                                 *)
(* ------------------------------------------------------------------------- *)

(* Run both procedures on the same formula and pair up their results: the    *)
(* first component comes from prawitz, the second from davisputnam.          *)
(* NOTE(review): from this point on the name compare refers to this          *)
(* function rather than to the polymorphic comparison of the standard        *)
(* library; confirm that no later code relies on the latter.                 *)
let compare fm =
  prawitz fm,davisputnam fm;;

(*
let p19 = compare
 <<exists x. forall y z. (P(y) ==> Q(z)) ==> P(x) ==> Q(x)>>;;

let p20 = compare
 <<(forall x y. exists z. forall w. P(x) /\ Q(y) ==> R(z) /\ U(w))
   ==> (exists x y. P(x) /\ Q(y)) ==> (exists z. R(z))>>;;

let p24 = compare
 <<~(exists x. U(x) /\ Q(x)) /\
   (forall x. P(x) ==> Q(x) \/ R(x)) /\
   ~(exists x. P(x) ==> (exists x. Q(x))) /\
   (forall x. Q(x) /\ R(x) ==> U(x))
   ==> (exists x. P(x) /\ R(x))>>;;

let p39 = compare
 <<~(exists x. forall y. P(y,x) <=> ~P(y,y))>>;;

let p42 = compare
 <<~(exists y. forall x. P(x,y) <=> ~(exists z. P(x,z) /\ P(z,x)))>>;;

let p44 = compare
 <<(forall x. P(x) ==> (exists y. G(y) /\ H(x,y)) /\
   (exists y. G(y) /\ ~H(x,y))) /\
   (exists x. J(x) /\ (forall y. G(y) ==> H(x,y)))
   ==> (exists x. J(x) /\ ~P(x))>>;;

let p59 = compare
 <<(forall x. P(x) <=> ~P(f(x))) ==> (exists x. P(x) /\ ~P(f(x)))>>;;

let p60 = compare
 <<forall x. P(x,f(x)) <=>
             exists y. (forall z. P(z,y) ==> P(z,f(x))) /\ P(x,y)>>;;

*)

(* ------------------------------------------------------------------------- *)
(* More standard tableau procedure, effectively doing DNF incrementally.     *)
(* ------------------------------------------------------------------------- *)

(* Tableau search in continuation-passing style.                             *)
(*                                                                           *)
(*   fms  - formulas still to be expanded on the current branch              *)
(*   lits - formulas already set aside on this branch as closing candidates  *)
(*   n    - remaining budget; it is decremented once per Forall expansion    *)
(*   cont - continuation applied to (env,k) once the current branch closes   *)
(*   env  - current unification environment                                  *)
(*   k    - counter used to name the next fresh variable                     *)
(*                                                                           *)
(* Raises Failure when the budget is exhausted or when there is nothing      *)
(* left to expand on a branch that has not been closed.                      *)
let rec tableau (fms,lits,n) cont (env,k) =
  if n < 0 then failwith "no proof at this level" else
  match fms with
    [] -> failwith "tableau: no proof"
  | And(p,q)::unexp ->
      (* Both conjuncts stay on the same branch. *)
      tableau (p::q::unexp,lits,n) cont (env,k)
  | Or(p,q)::unexp ->
      (* Close the p branch first; its continuation then closes the q branch *)
      (* starting from the environment reached on the p branch.              *)
      tableau (p::unexp,lits,n) (tableau (q::unexp,lits,n) cont) (env,k)
  | Forall(x,p)::unexp ->
      (* Instantiate with a fresh variable named from k, and requeue the     *)
      (* quantified formula at the end so that it can be used again.         *)
      let y = Var("_" ^ string_of_int k) in
      let p' = formsubst (x := y) p in
      tableau (p'::unexp@[Forall(x,p)],lits,n-1) cont (env,k+1)
  | fm::unexp ->
      (* Try to close the branch by unifying fm with the complement of one   *)
      (* of the stored formulas; if every attempt fails, store fm and go on. *)
      try tryfind (fun l -> cont(unify_complements env (fm,l),k)) lits
      with Failure _ -> tableau (unexp,fm::lits,n) cont (env,k);;

(* Iterative deepening: announce the depth limit n, then try f n.  Whenever  *)
(* the attempt raises Failure, start again with the limit increased by one.  *)
(* Any other exception propagates to the caller.                             *)
let rec deepen f n =
  let attempt () =
    print_string "Searching with depth limit ";
    print_int n;
    print_newline();
    f n in
  try attempt() with Failure _ -> deepen f (n + 1);;

(* Refute the formulas fms by tableau search under iterative deepening.      *)
(* The environment found is discarded; the result is the depth limit at      *)
(* which the search first succeeded.                                         *)
let tabrefute fms =
  let attempt n =
    ignore(tableau (fms,[],n) (fun final -> final) (undefined,0));
    n in
  deepen attempt 0;;

(* Tableau prover.  Works on the Skolemized negation of the generalized      *)
(* goal: returns 0 when that is False, raises Failure when it is True, and   *)
(* otherwise hands it to tabrefute as a single formula.                      *)
let tab fm =
  match askolemize(Not(generalize fm)) with
    False -> 0
  | True -> failwith "tab: no proof"
  | sfm -> tabrefute [sfm];;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
let p38 = tab
 <<(forall x.
     P(a) /\ (P(x) ==> (exists y. P(y) /\ R(x,y))) ==>
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))) <=>
   (forall x.
     (~P(a) \/ P(x) \/ (exists z w. P(z) /\ R(x,w) /\ R(w,z))) /\
     (~P(a) \/ ~(exists y. P(y) /\ R(x,y)) \/
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))))>>;;

let p45 = tab
 <<(forall x.
     P(x) /\ (forall y. G(y) /\ H(x,y) ==> J(x,y)) ==>
       (forall y. G(y) /\ H(x,y) ==> R(y))) /\
   ~(exists y. L(y) /\ R(y)) /\
   (exists x. P(x) /\ (forall y. H(x,y) ==>
     L(y)) /\ (forall y. G(y) /\ H(x,y) ==> J(x,y))) ==>
   (exists x. P(x) /\ ~(exists y. G(y) /\ H(x,y)))>>;;

let gilmore_9 = tab
 <<forall x. exists y. forall z.
     ((forall u. exists v. F(y,u,v) /\ G(y,u) /\ ~H(y,x))
       ==> (forall u. exists v. F(x,u,v) /\ G(z,u) /\ ~H(x,z))
          ==> (forall u. exists v. F(x,u,v) /\ G(y,u) /\ ~H(x,y))) /\
     ((forall u. exists v. F(x,u,v) /\ G(y,u) /\ ~H(x,y))
      ==> ~(forall u. exists v. F(x,u,v) /\ G(z,u) /\ ~H(x,z))
          ==> (forall u. exists v. F(y,u,v) /\ G(y,u) /\ ~H(y,x)) /\
              (forall u. exists v. F(z,u,v) /\ G(y,u) /\ ~H(z,y)))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Try to split up the initial formula first; often a big improvement.       *)
(* ------------------------------------------------------------------------- *)

(* Split the Skolemized negation of the generalized goal into its DNF        *)
(* members and refute each one separately with tabrefute, returning the      *)
(* list of results.                                                          *)
let splittab fm =
  let negated = askolemize(Not(generalize fm)) in
  let branches = simpdnf negated in
  map tabrefute branches;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
let p34 = splittab
 <<((exists x. forall y. P(x) <=> P(y)) <=>
    ((exists x. Q(x)) <=> (forall y. Q(y)))) <=>
    ((exists x. forall y. Q(x) <=> Q(y)) <=>
   ((exists x. P(x)) <=> (forall y. P(y))))>>;;

let p46 = splittab
 <<(forall x. P(x) /\ (forall y. P(y) /\ H(y,x) ==> G(y)) ==> G(x)) /\
    ((exists x. P(x) /\ ~G(x)) ==>
     (exists x. P(x) /\ ~G(x) /\
                (forall y. P(y) /\ ~G(y) ==> J(x,y)))) /\
    (forall x y. P(x) /\ P(y) /\ H(x,y) ==> ~J(y,x)) ==>
    (forall x. P(x) ==> G(x))>>;;

(* ------------------------------------------------------------------------- *)
(* Another nice example from EWD 1062.                                       *)
(* ------------------------------------------------------------------------- *)

let ewd1062 = splittab
 <<(forall x. x <= x) /\
   (forall x y z. x <= y /\ y <= z ==> x <= z) /\
   (forall x y. f(x) <= y <=> x <= g(y))
   ==> (forall x y. x <= y ==> f(x) <= f(y)) /\
       (forall x y. x <= y ==> g(x) <= g(y))>>;;

(* ------------------------------------------------------------------------- *)
(* Well-known "Agatha" example; cf. Manthey and Bry, CADE-9.                 *)
(* ------------------------------------------------------------------------- *)

let p55 = time splittab
 <<lives(agatha) /\ lives(butler) /\ lives(charles) /\
   (killed(agatha,agatha) \/ killed(butler,agatha) \/
    killed(charles,agatha)) /\
   (forall x y. killed(x,y) ==> hates(x,y) /\ ~richer(x,y)) /\
   (forall x. hates(agatha,x) ==> ~hates(charles,x)) /\
   (hates(agatha,agatha) /\ hates(agatha,charles)) /\
   (forall x. lives(x) /\ ~richer(x,agatha) ==> hates(butler,x)) /\
   (forall x. hates(agatha,x) ==> hates(butler,x)) /\
   (forall x. ~hates(x,agatha) \/ ~hates(x,butler) \/ ~hates(x,charles))
   ==> killed(agatha,agatha) /\
       ~killed(butler,agatha) /\
       ~killed(charles,agatha)>>;;

(* ------------------------------------------------------------------------- *)
(* Example from Davis-Putnam papers where Gilmore procedure is poor.         *)
(* ------------------------------------------------------------------------- *)

let davis_putnam_example = time splittab
 <<exists x. exists y. forall z.
        (F(x,y) ==> (F(y,z) /\ F(z,z))) /\
        ((F(x,y) /\ G(x,y)) ==> (G(x,z) /\ G(z,z)))>>;;

*)
(* ========================================================================= *)
(* Resolution.                                                               *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Barber's paradox is an example of why we need factoring.                  *)
(* ------------------------------------------------------------------------- *)

(* The barber formula, written in the quotation syntax used by this file     *)
(* for formulas: there is no b who shaves exactly those x that do not shave  *)
(* themselves.                                                               *)
let barb = <<~(exists b. forall x. shaves(b,x) <=> ~shaves(x,x))>>;;

(*
clausal(skolemize(Not barb));;
*)

(* ------------------------------------------------------------------------- *)
(* MGU of a set of literals.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Unify all the literals in l with one another by unifying each adjacent    *)
(* pair in turn, threading env through; when fewer than two literals         *)
(* remain, the accumulated environment is passed through solve.              *)
let rec mgu l env =
  match l with
    [] | [_] -> solve env
  | a::(b::_ as tail) -> mgu tail (unify_literals env (a,b));;

(* True exactly when unify_literals, started from the undefined              *)
(* environment, handles the pair (p,q) without raising Failure.              *)
let unifiable p q =
  let attempt = unify_literals undefined in
  can attempt (p,q);;

(* ------------------------------------------------------------------------- *)
(* Rename a clause.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Rename every free variable of the clause cls by prepending pfx to its     *)
(* name, and return the renamed literals.                                    *)
let rename pfx cls =
  let fvs = fv(list_disj cls) in
  let prefixed v = Var(pfx^v) in
  let renaming = formsubst(instantiate fvs (map prefixed fvs)) in
  map renaming cls;;

(* ------------------------------------------------------------------------- *)
(* General resolution rule, incorporating factoring as in Robinson's paper.  *)
(* ------------------------------------------------------------------------- *)

(* Add to acc every resolvent of cl1 and cl2 obtained by resolving on the    *)
(* literal p of cl1.  The literals of cl2 unifiable with the negation of p   *)
(* are collected first; if there are none, acc is returned unchanged.        *)
(* Otherwise p, together with each subset of the other literals of cl1       *)
(* unifiable with it, is paired with each nonempty subset of the cl2         *)
(* candidates.  A pair contributes a resolvent when mgu succeeds on it and   *)
(* is skipped when Failure is raised.                                        *)
let resolvents cl1 cl2 p acc =
  match filter (unifiable(negate p)) cl2 with
    [] -> acc
  | ps2 ->
      let ps1 = filter (fun q -> q <> p && unifiable p q) cl1 in
      let lefts = map (fun pl -> p::pl) (allsubsets ps1) in
      let rights = allnonemptysubsets ps2 in
      let pairs = allpairs (fun s1 s2 -> s1,s2) lefts rights in
      let add_resolvent (s1,s2) sof =
        try
          let sigma = mgu (s1 @ map negate s2) undefined in
          let remainder = union (subtract cl1 s1) (subtract cl2 s2) in
          smap (formsubst sigma) remainder :: sof
        with Failure _ -> sof in
      itlist add_resolvent pairs acc;;

(* All resolvents of two clauses.  The clauses are first renamed apart by    *)
(* prefixing their variables with different letters; resolvents are then     *)
(* collected for every literal of the first renamed clause.                  *)
let resolve_clauses cls1 cls2 =
  let left = rename "x" cls1 in
  let right = rename "y" cls2 in
  itlist (fun p acc -> resolvents left right p acc) left [];;

(* ------------------------------------------------------------------------- *)
(* Basic "Argonne" loop.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Given-clause loop.  The first unused clause is moved into the used set    *)
(* and resolved against every clause there; the result is true as soon as    *)
(* the empty clause appears among the new resolvents.  Otherwise the new     *)
(* resolvents are appended to the unused list and the loop continues.        *)
(* Raises Failure when the unused list runs out.  Progress is printed on     *)
(* each round.                                                               *)
let rec resloop (used,unused) =
  match unused with
    [] -> failwith "No proof found"
  | cls::ros ->
      let nused = string_of_int(length used)
      and nunused = string_of_int(length unused) in
      print_string(nused ^ " used; " ^ nunused ^ " unused.");
      print_newline();
      let used' = insert cls used in
      let batches = mapfilter (resolve_clauses cls) used' in
      let news = itlist (@) batches [] in
      if not(mem [] news) then resloop (used',ros@news) else true;;

(* Convert fm to clauses (prenex form, specialized, then clausal) and run    *)
(* resloop with an empty used set.                                           *)
let pure_resolution fm =
  let clauses = clausal(specialize(pnf fm)) in
  resloop([],clauses);;

(* Resolution prover.  The Skolemized negation of the generalized goal is    *)
(* split into DNF members, and pure_resolution is run on the conjunction of  *)
(* each member; the list of results is returned.                             *)
let resolution fm =
  let negated = askolemize(Not(generalize fm)) in
  let refute_branch lits = pure_resolution (list_conj lits) in
  map refute_branch (simpdnf negated);;

(* ------------------------------------------------------------------------- *)
(* Simple example that works well.                                           *)
(* ------------------------------------------------------------------------- *)

(*
let davis_putnam_example = resolution
 <<exists x. exists y. forall z.
        (F(x,y) ==> (F(y,z) /\ F(z,z))) /\
        ((F(x,y) /\ G(x,y)) ==> (G(x,z) /\ G(z,z)))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Tautology checking for a clause (re-using dual code).                     *)
(* ------------------------------------------------------------------------- *)

(* A clause is checked for being a tautology with the same test that is      *)
(* used elsewhere under the name contradictory.                              *)
let tautologous cl = contradictory cl;;

(* ------------------------------------------------------------------------- *)
(* Test for subsumption                                                      *)
(* ------------------------------------------------------------------------- *)

(* Test whether cls1 subsumes cls2.  Every free variable x of cls2 is first  *)
(* replaced by a constant whose name is x prefixed with an underscore.  The  *)
(* test then succeeds when each literal of cls1, taken in order, can be      *)
(* unified with some literal of the resulting clause under one accumulated   *)
(* environment.                                                              *)
let subsumes_clause cls1 cls2 =
  let fvs = itlist (fun lit acc -> union (fv lit) acc) cls2 [] in
  let freeze = itlist (fun x -> x |-> Fn("_"^x,[])) fvs undefined in
  let frozen = map (formsubst freeze) cls2 in
  let rec cover env pending =
    match pending with
      [] -> env
    | lit::rest ->
        let via target = cover (unify_literals env (lit,target)) rest in
        tryfind via frozen in
  can (cover undefined) cls1;;

(* ------------------------------------------------------------------------- *)
(* With deletion of tautologies and bi-subsumption with "unused".            *)
(* ------------------------------------------------------------------------- *)

(* Put cl in place of the first clause of lis that it subsumes; if it        *)
(* subsumes none of them, add it at the end of the list.                     *)
let rec replace cl lis =
  match lis with
    [] -> [cl]
  | c::cls ->
      if not(subsumes_clause cl c) then c::(replace cl cls)
      else cl::cls;;

(* Add the new clause cl to the unused list unless it is redundant, that     *)
(* is, tautologous or subsumed by the given clause gcl or by a clause        *)
(* already in unused.  A clause that is kept goes in through replace.        *)
let incorporate gcl cl unused =
  let redundant =
    tautologous cl ||
    exists (fun c -> subsumes_clause c cl) (gcl::unused) in
  if redundant then unused else replace cl unused;;

(* Given-clause loop with redundancy control.  It differs from the earlier   *)
(* resloop only in how new resolvents join the unused list: each one goes    *)
(* through incorporate, with the current given clause as extra subsumer.     *)
(* Returns true when the empty clause is derived and raises Failure when     *)
(* the unused list runs out.                                                 *)
let rec resloop (used,unused) =
  match unused with
    [] -> failwith "No proof found"
  | cls::ros ->
      let status =
        string_of_int(length used) ^ " used; " ^
        string_of_int(length unused) ^ " unused." in
      print_string status;
      print_newline();
      let used' = insert cls used in
      let batches = mapfilter (resolve_clauses cls) used' in
      let news = itlist (@) batches [] in
      if mem [] news then true
      else resloop(used',itlist (incorporate cls) news ros);;

(* Convert fm to clauses, drop the tautologous ones, and run resloop with    *)
(* an empty used set.                                                        *)
let pure_resolution fm =
  let clauses = clausal(specialize(pnf fm)) in
  let useful = filter (non tautologous) clauses in
  resloop([],useful);;

(* Resolution prover built on the pure_resolution defined just above: one    *)
(* run per DNF member of the Skolemized negation of the generalized goal.    *)
let resolution fm =
  let fm1 = askolemize(Not(generalize fm)) in
  map (fun lits -> pure_resolution (list_conj lits)) (simpdnf fm1);;

(* ------------------------------------------------------------------------- *)
(* This is now a lot quicker.                                                *)
(* ------------------------------------------------------------------------- *)

(*
let davis_putnam_example = resolution
 <<exists x. exists y. forall z.
        (F(x,y) ==> (F(y,z) /\ F(z,z))) /\
        ((F(x,y) /\ G(x,y)) ==> (G(x,z) /\ G(z,z)))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Positive (P1) resolution.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Positive resolution: resolvents are produced only when at least one of    *)
(* the two clauses consists entirely of positive literals; otherwise the     *)
(* result is the empty list.                                                 *)
let presolve_clauses cls1 cls2 =
  if not(forall positive cls1) && not(forall positive cls2) then []
  else resolve_clauses cls1 cls2;;

(* Given-clause loop for positive resolution.  It has the same shape as      *)
(* resloop, but new clauses come from presolve_clauses.  Returns true when   *)
(* the empty clause is derived and raises Failure when the unused list runs  *)
(* out.                                                                      *)
let rec presloop (used,unused) =
  match unused with
    [] -> failwith "No proof found"
  | cls::ros ->
      let nused = string_of_int(length used)
      and nunused = string_of_int(length unused) in
      print_string(nused ^ " used; " ^ nunused ^ " unused.");
      print_newline();
      let used' = insert cls used in
      let batches = mapfilter (presolve_clauses cls) used' in
      let news = itlist (@) batches [] in
      if not(mem [] news)
      then presloop(used',itlist (incorporate cls) news ros)
      else true;;

(* Convert fm to clauses, drop the tautologous ones, and run presloop with   *)
(* an empty used set.                                                        *)
let pure_presolution fm =
  let clauses = clausal(specialize(pnf fm)) in
  presloop([],filter (non tautologous) clauses);;

(* Positive-resolution prover: one pure_presolution run per DNF member of    *)
(* the Skolemized negation of the generalized goal.                          *)
let presolution fm =
  let negated = askolemize(Not(generalize fm)) in
  let refute_branch lits = pure_presolution (list_conj lits) in
  map refute_branch (simpdnf negated);;

(* ------------------------------------------------------------------------- *)
(* Example: the (in)famous Los problem.                                      *)
(* ------------------------------------------------------------------------- *)

(*
let los = time presolution
 <<(forall x y z. P(x,y) ==> P(y,z) ==> P(x,z)) /\
   (forall x y z. Q(x,y) ==> Q(y,z) ==> Q(x,z)) /\
   (forall x y. Q(x,y) ==> Q(y,x)) /\
   (forall x y. P(x,y) \/ Q(x,y))
   ==> (forall x y. P(x,y)) \/ (forall x y. Q(x,y))>>;;

(* ------------------------------------------------------------------------- *)
(* Example from Manthey and Bry, CADE-9.                                     *)
(* ------------------------------------------------------------------------- *)

let p55 = time presolution
 <<lives(agatha) /\ lives(butler) /\ lives(charles) /\
   (killed(agatha,agatha) \/ killed(butler,agatha) \/
    killed(charles,agatha)) /\
   (forall x y. killed(x,y) ==> hates(x,y) /\ ~richer(x,y)) /\
   (forall x. hates(agatha,x) ==> ~hates(charles,x)) /\
   (hates(agatha,agatha) /\ hates(agatha,charles)) /\
   (forall x. lives(x) /\ ~richer(x,agatha) ==> hates(butler,x)) /\
   (forall x. hates(agatha,x) ==> hates(butler,x)) /\
   (forall x. ~hates(x,agatha) \/ ~hates(x,butler) \/ ~hates(x,charles))
   ==> killed(agatha,agatha) /\
       ~killed(butler,agatha) /\
       ~killed(charles,agatha)>>;;

(* ------------------------------------------------------------------------- *)
(* From Gilmore's classic paper.                                             *)
(* ------------------------------------------------------------------------- *)

let gilmore_1 = time presolution
 <<exists x. forall y z.
      ((F(y) ==> G(y)) <=> F(x)) /\
      ((F(y) ==> H(y)) <=> G(x)) /\
      (((F(y) ==> G(y)) ==> H(y)) <=> H(x))
      ==> F(z) /\ G(z) /\ H(z)>>;;

let gilmore_3 = time presolution
 <<exists x. forall y z.
        ((F(y,z) ==> (G(y) ==> H(x))) ==> F(x,x)) /\
        ((F(z,x) ==> G(x)) ==> H(z)) /\
        F(x,y)
        ==> F(z,z)>>;;

let gilmore_4 = time presolution
 <<exists x y. forall z.
        (F(x,y) ==> F(y,z) /\ F(z,z)) /\
        (F(x,y) /\ G(x,y) ==> G(x,z) /\ G(z,z))>>;;

let gilmore_5 = time presolution
 <<(forall x. exists y. F(x,y) \/ F(y,x)) /\
   (forall x y. F(y,x) ==> F(y,y))
   ==> exists z. F(z,z)>>;;

let gilmore_6 = time presolution
 <<forall x. exists y.
        (exists u. forall v. F(u,x) ==> G(v,u) /\ G(u,x))
        ==> (exists u. forall v. F(u,y) ==> G(v,u) /\ G(u,y)) \/
            (forall u v. exists w. G(v,u) \/ H(w,y,u) ==> G(u,w))>>;;

let gilmore_7 = time presolution
 <<(forall x. K(x) ==> exists y. L(y) /\ (F(x,y) ==> G(x,y))) /\
   (exists z. K(z) /\ forall u. L(u) ==> F(z,u))
   ==> exists v w. K(v) /\ L(w) /\ G(v,w)>>;;

let gilmore_8 = time presolution
 <<exists x. forall y z.
        ((F(y,z) ==> (G(y) ==> (forall u. exists v. H(u,v,x)))) ==> F(x,x)) /\
        ((F(z,x) ==> G(x)) ==> (forall u. exists v. H(u,v,z))) /\
        F(x,y)
        ==> F(z,z)>>;;

(* ------------------------------------------------------------------------- *)
(* Example from Davis-Putnam papers where Gilmore procedure is poor.         *)
(* ------------------------------------------------------------------------- *)

let davis_putnam_example = time presolution
 <<exists x. exists y. forall z.
        (F(x,y) ==> (F(y,z) /\ F(z,z))) /\
        ((F(x,y) /\ G(x,y)) ==> (G(x,z) /\ G(z,z)))>>;;

*)

(* ------------------------------------------------------------------------- *)
(* Introduce a set-of-support restriction.                                   *)
(* ------------------------------------------------------------------------- *)

(* Set-of-support variant.  After tautologies are removed, the clauses that  *)
(* contain a positive literal start out in the used set, and the remaining   *)
(* clauses form the initial unused list.                                     *)
let pure_resolution fm =
  let cls = filter (non tautologous) (clausal(specialize(pnf fm))) in
  let used,unused = partition (exists positive) cls in
  resloop(used,unused);;

(* Resolution prover built on the set-of-support pure_resolution defined     *)
(* just above: one run per DNF member of the Skolemized negation of the      *)
(* generalized goal.                                                         *)
let resolution fm =
  let branches = simpdnf(askolemize(Not(generalize fm))) in
  map (fun lits -> pure_resolution (list_conj lits)) branches;;

(* ------------------------------------------------------------------------- *)
(* Example                                                                   *)
(* ------------------------------------------------------------------------- *)

(*
let gilmore_1 = resolution
 <<exists x. forall y z.
      ((F(y) ==> G(y)) <=> F(x)) /\
      ((F(y) ==> H(y)) <=> G(x)) /\
      (((F(y) ==> G(y)) ==> H(y)) <=> H(x))
      ==> F(z) /\ G(z) /\ H(z)>>;;

(* ------------------------------------------------------------------------- *)
(* Some Pelletier problems.                                                  *)
(* ------------------------------------------------------------------------- *)

let p1 = time resolution
 <<p ==> q <=> ~q ==> ~p>>;;

let p2 = time resolution
 <<~ ~p <=> p>>;;

let p3 = time resolution
 <<~(p ==> q) ==> q ==> p>>;;

let p4 = time resolution
 <<~p ==> q <=> ~q ==> p>>;;

let p5 = time resolution
 <<(p \/ q ==> p \/ r) ==> p \/ (q ==> r)>>;;

let p6 = time resolution
 <<p \/ ~p>>;;

let p7 = time resolution
 <<p \/ ~ ~ ~p>>;;

let p8 = time resolution
 <<((p ==> q) ==> p) ==> p>>;;

let p9 = time resolution
 <<(p \/ q) /\ (~p \/ q) /\ (p \/ ~q) ==> ~(~q \/ ~q)>>;;

let p10 = time resolution
 <<(q ==> r) /\ (r ==> p /\ q) /\ (p ==> q /\ r) ==> (p <=> q)>>;;

let p11 = time resolution
 <<p <=> p>>;;

let p12 = time resolution
 <<((p <=> q) <=> r) <=> (p <=> (q <=> r))>>;;

let p13 = time resolution
 <<p \/ q /\ r <=> (p \/ q) /\ (p \/ r)>>;;

let p14 = time resolution
 <<(p <=> q) <=> (q \/ ~p) /\ (~q \/ p)>>;;

let p15 = time resolution
 <<p ==> q <=> ~p \/ q>>;;

let p16 = time resolution
 <<(p ==> q) \/ (q ==> p)>>;;

let p17 = time resolution
 <<p /\ (q ==> r) ==> s <=> (~p \/ q \/ s) /\ (~p \/ ~r \/ s)>>;;

let p18 = time resolution
 <<exists y. forall x. P(y) ==> P(x)>>;;

let p19 = time resolution
 <<exists x. forall y z. (P(y) ==> Q(z)) ==> P(x) ==> Q(x)>>;;

let p20 = time resolution
 <<(forall x y. exists z. forall w. P(x) /\ Q(y) ==> R(z) /\ U(w)) ==>
   (exists x y. P(x) /\ Q(y)) ==>
   (exists z. R(z))>>;;

let p21 = time resolution
 <<(exists x. P ==> Q(x)) /\ (exists x. Q(x) ==> P) ==> (exists x. P <=> Q(x))>>;;

let p22 = time resolution
 <<(forall x. P <=> Q(x)) ==> (P <=> (forall x. Q(x)))>>;;

let p23 = time resolution
 <<(forall x. P \/ Q(x)) <=> P \/ (forall x. Q(x))>>;;

let p24 = time resolution
 <<~(exists x. U(x) /\ Q(x)) /\
   (forall x. P(x) ==> Q(x) \/ R(x)) /\
   ~(exists x. P(x) ==> (exists x. Q(x))) /\
   (forall x. Q(x) /\ R(x) ==> U(x)) ==>
   (exists x. P(x) /\ R(x))>>;;

let p25 = time resolution
 <<(exists x. P(x)) /\
   (forall x. U(x) ==> ~G(x) /\ R(x)) /\
   (forall x. P(x) ==> G(x) /\ U(x)) /\
   ((forall x. P(x) ==> Q(x)) \/ (exists x. Q(x) /\ P(x))) ==>
   (exists x. Q(x) /\ P(x))>>;;

let p26 = time resolution
 <<((exists x. P(x)) <=> (exists x. Q(x))) /\
   (forall x y. P(x) /\ Q(y) ==> (R(x) <=> U(y))) ==>
   ((forall x. P(x) ==> R(x)) <=> (forall x. Q(x) ==> U(x)))>>;;

let p27 = time resolution
 <<(exists x. P(x) /\ ~Q(x)) /\
   (forall x. P(x) ==> R(x)) /\
   (forall x. U(x) /\ V(x) ==> P(x)) /\
   (exists x. R(x) /\ ~Q(x)) ==>
   (forall x. U(x) ==> ~R(x)) ==>
   (forall x. U(x) ==> ~V(x))>>;;

let p28 = time resolution
 <<(forall x. P(x) ==> (forall x. Q(x))) /\
   ((forall x. Q(x) \/ R(x)) ==> (exists x. Q(x) /\ R(x))) /\
   ((exists x. R(x)) ==> (forall x. L(x) ==> M(x))) ==>
   (forall x. P(x) /\ L(x) ==> M(x))>>;;

let p29 = time resolution
 <<(exists x. P(x)) /\ (exists x. G(x)) ==>
   ((forall x. P(x) ==> H(x)) /\ (forall x. G(x) ==> J(x)) <=>
    (forall x y. P(x) /\ G(y) ==> H(x) /\ J(y)))>>;;

let p30 = time resolution
 <<(forall x. P(x) \/ G(x) ==> ~H(x)) /\ (forall x. (G(x) ==> ~U(x)) ==>
     P(x) /\ H(x)) ==>
   (forall x. U(x))>>;;

let p31 = time resolution
 <<~(exists x. P(x) /\ (G(x) \/ H(x))) /\ (exists x. Q(x) /\ P(x)) /\
   (forall x. ~H(x) ==> J(x)) ==>
   (exists x. Q(x) /\ J(x))>>;;

let p32 = time resolution
 <<(forall x. P(x) /\ (G(x) \/ H(x)) ==> Q(x)) /\
   (forall x. Q(x) /\ H(x) ==> J(x)) /\
   (forall x. R(x) ==> H(x)) ==>
   (forall x. P(x) /\ R(x) ==> J(x))>>;;

let p33 = time resolution
 <<(forall x. P(a) /\ (P(x) ==> P(b)) ==> P(c)) <=>
   (forall x. P(a) ==> P(x) \/ P(c)) /\ (P(a) ==> P(b) ==> P(c))>>;;

let p34 = time resolution
 <<((exists x. forall y. P(x) <=> P(y)) <=>
   ((exists x. Q(x)) <=> (forall y. Q(y)))) <=>
   ((exists x. forall y. Q(x) <=> Q(y)) <=>
  ((exists x. P(x)) <=> (forall y. P(y))))>>;;

let p35 = time resolution
 <<exists x y. P(x,y) ==> (forall x y. P(x,y))>>;;

let p36 = time resolution
 <<(forall x. exists y. P(x,y)) /\
   (forall x. exists y. G(x,y)) /\
   (forall x y. P(x,y) \/ G(x,y)
   ==> (forall z. P(y,z) \/ G(y,z) ==> H(x,z)))
       ==> (forall x. exists y. H(x,y))>>;;

let p37 = time resolution
 <<(forall z.
     exists w. forall x. exists y. (P(x,z) ==> P(y,w)) /\ P(y,z) /\
     (P(y,w) ==> (exists u. Q(u,w)))) /\
   (forall x z. ~P(x,z) ==> (exists y. Q(y,z))) /\
   ((exists x y. Q(x,y)) ==> (forall x. R(x,x))) ==>
   (forall x. exists y. R(x,y))>>;;

let p39 = time resolution
 <<~(exists x. forall y. P(y,x) <=> ~P(y,y))>>;;

let p40 = time resolution
 <<(exists y. forall x. P(x,y) <=> P(x,x))
  ==> ~(forall x. exists y. forall z. P(z,y) <=> ~P(z,x))>>;;

let p41 = time resolution
 <<(forall z. exists y. forall x. P(x,y) <=> P(x,z) /\ ~P(x,x))
  ==> ~(exists z. forall x. P(x,z))>>;;

let p44 = time resolution
 <<(forall x. P(x) ==> (exists y. G(y) /\ H(x,y)) /\
   (exists y. G(y) /\ ~H(x,y))) /\
   (exists x. J(x) /\ (forall y. G(y) ==> H(x,y))) ==>
   (exists x. J(x) /\ ~P(x))>>;;

let p55 = time resolution
 <<lives(agatha) /\ lives(butler) /\ lives(charles) /\
   (killed(agatha,agatha) \/ killed(butler,agatha) \/
    killed(charles,agatha)) /\
   (forall x y. killed(x,y) ==> hates(x,y) /\ ~richer(x,y)) /\
   (forall x. hates(agatha,x) ==> ~hates(charles,x)) /\
   (hates(agatha,agatha) /\ hates(agatha,charles)) /\
   (forall x. lives(x) /\ ~richer(x,agatha) ==> hates(butler,x)) /\
   (forall x. hates(agatha,x) ==> hates(butler,x)) /\
   (forall x. ~hates(x,agatha) \/ ~hates(x,butler) \/ ~hates(x,charles))
   ==> killed(agatha,agatha) /\
       ~killed(butler,agatha) /\
       ~killed(charles,agatha)>>;;

let p57 = time resolution
 <<P(f((a),b),f(b,c)) /\
   P(f(b,c),f(a,c)) /\
   (forall (x) y z. P(x,y) /\ P(y,z) ==> P(x,z))
   ==> P(f(a,b),f(a,c))>>;;

*)
(* ========================================================================= *)
(* Backchaining procedure for Horn clauses, and toy Prolog implementation.   *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Rename a rule.                                                            *)
(* ------------------------------------------------------------------------- *)

(* [renamer k (asm,c)] renames every free variable of the rule (asm,c) to   *)
(* a variable named "_i", with i counting upwards from k. It returns the    *)
(* renamed rule paired with k plus the number of free variables, i.e. the   *)
(* first index not used by this renaming.                                    *)
let renamer k (asm,c) =
  let fvs = fv(list_conj(c::asm)) in
  let k' = k + length fvs in
  let fresh i = Var("_" ^ string_of_int i) in
  let inst = formsubst(instantiate fvs (map fresh (k -- (k' - 1)))) in
  (map inst asm,inst c),k';;

(* ------------------------------------------------------------------------- *)
(* Basic prover for Horn clauses based on backchaining with unification.     *)
(* ------------------------------------------------------------------------- *)

(* [backchain rules n k env goals] returns the environment once no goals    *)
(* remain. For the first goal it tries each rule in turn (via tryfind):     *)
(* the rule is renamed starting at index k, its conclusion is unified with  *)
(* the goal, and its hypotheses are put in front of the remaining goals.    *)
(* n is decremented on each step and the search fails with "Too deep" when  *)
(* it is exactly 0, so a negative n never triggers the bound.               *)
let rec backchain rules n k env goals =
  match goals with
    [] -> env
  | goal::others ->
      if n = 0 then failwith "Too deep" else
      let attempt rule =
        let (hyps,concl),k' = renamer k rule in
        let env' = unify_literals env (concl,goal) in
        backchain rules (n - 1) k' env' (hyps @ others) in
      tryfind attempt rules;;

(* Turn a clause (list of literals) into a rule (hypotheses,conclusion).    *)
(* The negative literals are negated to form the hypotheses; the single     *)
(* positive literal, if any, is the conclusion, otherwise False is used.    *)
(* Fails with "non-Horn clause" when there are two or more positive ones.   *)
let hornify cls =
  let pos,neg = partition positive cls in
  match pos with
    [] -> (map negate neg,False)
  | [concl] -> (map negate neg,concl)
  | _ -> failwith "non-Horn clause";;

(* Negate the generalized formula, Skolemize it, convert the clauses to     *)
(* Horn rules and search for a derivation of False by backchaining. The     *)
(* depth bound is supplied by deepen (starting from 0); each attempt        *)
(* returns the resulting environment paired with the bound that was used.   *)
let hornprove fm =
  let refutand = skolemize(Not(generalize fm)) in
  let rules = map hornify (clausal refutand) in
  let attempt depth = (backchain rules depth 0 undefined [False],depth) in
  deepen attempt 0;;

(* ------------------------------------------------------------------------- *)
(* Some Horn examples.                                                       *)
(* ------------------------------------------------------------------------- *)

(*
let p1 = hornprove
 <<p ==> q <=> ~q ==> ~p>>;;

let p18 = hornprove
 <<exists y. forall x. P(y) ==> P(x)>>;;

let p32 = hornprove
 <<(forall x. P(x) /\ (G(x) \/ H(x)) ==> Q(x)) /\
   (forall x. Q(x) /\ H(x) ==> J(x)) /\
   (forall x. R(x) ==> H(x)) ==>
   (forall x. P(x) /\ R(x) ==> J(x))>>;;

*)

(* ------------------------------------------------------------------------- *)
(* Parsing rules in a Prolog-like syntax.                                    *)
(* ------------------------------------------------------------------------- *)

(* Parse a rule written as "conclusion" or "conclusion :- h1, ..., hn".     *)
(* Returns (hypotheses,conclusion); the hypothesis list is empty when no    *)
(* ":-" token follows the conclusion. Fails with "Extra material after      *)
(* rule" if any tokens are left over once the rule has been read.           *)
let parserule s =
  let c,rest = parse_formula parse_atom [] (lex(explode s)) in
  let asm,rest1 =
    (* Matching on the token list replaces the deprecated "&" test.         *)
    match rest with
      ":-"::body -> parse_list "," (parse_formula parse_atom []) body
    | _ -> [],rest in
  if rest1 = [] then (asm,c) else failwith "Extra material after rule";;

(* ------------------------------------------------------------------------- *)
(* Prolog interpreter: just use depth-first search not iterative deepening.  *)
(* ------------------------------------------------------------------------- *)

(* Parse the goal string and the rule strings, then run backchain with a    *)
(* depth argument of -1, which never reaches the 0 that triggers "Too       *)
(* deep", so the search is unbounded. The goal is bound before the rules    *)
(* to keep the order that right-to-left argument evaluation would give.     *)
let simpleprolog rules gl =
  let goal = parse gl in
  let clauses = map parserule rules in
  backchain clauses (-1) 0 undefined [goal];;

(* ------------------------------------------------------------------------- *)
(* ML version of the first Prolog example.                                   *)
(* ------------------------------------------------------------------------- *)

(* Unary natural numbers: Z is zero and S n is the successor of n.          *)
type numeral = Z | S of numeral;;

(* [less_or_equal (m,n)] tests m <= n on unary numerals. The pair           *)
(* (S _, Z) is matched by neither of the first two cases, so it is given    *)
(* an explicit false result rather than raising Match_failure.              *)
let rec less_or_equal =
  function (Z,_) -> true
         | (S(x),S(y)) -> less_or_equal (x,y)
         | (S(_),Z) -> false;;

(* ------------------------------------------------------------------------- *)
(* Ordering example.                                                         *)
(* ------------------------------------------------------------------------- *)

(*
let lerules = ["0 <= X"; "S(X) <= S(Y) :- X <= Y"];;

simpleprolog lerules "S(S(0)) <= S(S(S(0)))";;

let env = simpleprolog lerules "S(S(0)) <= X";;
apply env "X";;
*)

(* ------------------------------------------------------------------------- *)
(* With instantiation collection to produce a more readable result.          *)
(* ------------------------------------------------------------------------- *)

(* Run simpleprolog and pass the resulting environment through solve.       *)
(* Then, for each free variable x of the goal, produce the equation         *)
(* x = (its value under the solved instantiation). mapfilter is used, so    *)
(* variables for which the lookup fails are presumably just left out.       *)
let prolog rules gl =
  let inst = solve(simpleprolog rules gl) in
  let binding v =
    let value = apply inst v in
    Atom(R("=",[Var v; value])) in
  let vars = fv(parse gl) in
  mapfilter binding vars;;

(* ------------------------------------------------------------------------- *)
(* Example again.                                                            *)
(* ------------------------------------------------------------------------- *)

(*
prolog lerules "S(S(0)) <= X";;

(* ------------------------------------------------------------------------- *)
(* Append example, showing symmetry between inputs and outputs.              *)
(* ------------------------------------------------------------------------- *)

let appendrules =
  ["append(nil,L,L)"; "append(H::T,L,H::A) :- append(T,L,A)"];;

prolog appendrules "append(1::2::nil,3::4::nil,Z)";;

prolog appendrules "append(1::2::nil,Y,1::2::3::4::nil)";;

prolog appendrules "append(X,3::4::nil,1::2::3::4::nil)";;

prolog appendrules "append(X,Y,1::2::3::4::nil)";;

(* ------------------------------------------------------------------------- *)
(* A sorting example (from Lloyd's "Foundations of Logic Programming").      *)
(* ------------------------------------------------------------------------- *)

let sortrules =
 ["sort(X,Y) :- perm(X,Y),sorted(Y)";
  "sorted(nil)";
  "sorted(X::nil)";
  "sorted(X::Y::Z) :- X <= Y, sorted(Y::Z)";
  "perm(nil,nil)";
  "perm(X::Y,U::V) :- delete(U,X::Y,Z), perm(Z,V)";
  "delete(X,X::Y,Y)";
  "delete(X,Y::Z,Y::W) :- delete(X,Z,W)";
  "0 <= X";
  "S(X) <= S(Y) :- X <= Y"];;

prolog sortrules "sort(S(0)::0::nil,X)";;

prolog sortrules "sort(S(0)::S(S(0))::0::nil,X)";;

prolog sortrules
  "sort(S(S(S(S(0))))::S(0)::0::S(S(0))::S(0)::nil,X)";;

*)
(* ========================================================================= *)
(* Model elimination procedure (MESON version, based on Stickel's PTTP).     *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Example of naivety of tableau prover.                                     *)
(* ------------------------------------------------------------------------- *)

(*
tab <<forall a. ~(P(a) /\ (forall y z. Q(y) \/ R(z)) /\ ~P(a))>>;;

tab <<forall a. ~(P(a) /\ ~P(a) /\ (forall y z. Q(y) \/ R(z)))>>;;

(* ------------------------------------------------------------------------- *)
(* The interesting example where tableaux connections make the proof longer. *)
(* Unfortunately this gets hammered by normalization first...                *)
(* ------------------------------------------------------------------------- *)

let th = tab
 <<~p /\ (p \/ q) /\ (r \/ s) /\ (~q \/ t \/ u) /\
   (~r \/ ~t) /\ (~r \/ ~u) /\ (~q \/ v \/ w) /\
   (~s \/ ~v) /\ (~s \/ ~w) ==> false>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Generation of contrapositives.                                            *)
(* ------------------------------------------------------------------------- *)

(* For a clause cls, build one rule per literal: that literal is the        *)
(* conclusion and the negations of the remaining literals (cls with the     *)
(* literal subtracted) are the hypotheses. When every literal is negative   *)
(* an extra rule concluding False from all the negated literals is put at   *)
(* the front.                                                                *)
let contrapositives cls =
  let rule_for lit = (map negate (subtract cls [lit]),lit) in
  let base = map rule_for cls in
  if not (forall negative cls) then base
  else (map negate cls,False)::base;;

(* ------------------------------------------------------------------------- *)
(* The core of MESON: ancestor unification or Prolog-style extension.        *)
(* ------------------------------------------------------------------------- *)

(* [expand rules ancestors g cont (env,n,k)] tries to solve goal g with     *)
(* remaining size budget n, failing with "Too deep" when n is negative.     *)
(* It first tries to unify g with the negation of some ancestor and call    *)
(* the continuation; if every ancestor fails, it tries each rule: the rule  *)
(* is renamed from index k, its conclusion is unified with g, the number    *)
(* of its hypotheses is deducted from n, and the hypotheses are expanded    *)
(* in turn with g added to the ancestors.                                   *)
let rec expand rules ancestors g cont (env,n,k) =
  if n < 0 then failwith "Too deep" else
  let reduce anc = cont (unify_literals env (g,negate anc),n,k) in
  let extend rule =
    let (asm,c),k' = renamer k rule in
    (* Unify before anything else, as the original argument order did.      *)
    let state = (unify_literals env (g,c),n - length asm,k') in
    itlist (expand rules (g::ancestors)) asm cont state in
  try tryfind reduce ancestors
  with Failure _ -> tryfind extend rules;;

(* ------------------------------------------------------------------------- *)
(* Full MESON procedure.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Convert fm to clauses, collect the contrapositives of every clause as    *)
(* rules, and try to derive False with expand. The size bound is supplied   *)
(* by deepen (starting from 0); the value returned by expand is discarded   *)
(* and the bound at which the attempt succeeded is returned instead.        *)
let puremeson fm =
  let cls = clausal(specialize(pnf fm)) in
  let add_rules cl acc = contrapositives cl @ acc in
  let rules = itlist add_rules cls [] in
  let attempt bound =
    ignore(expand rules [] False (fun x -> x) (undefined,bound,0));
    bound in
  deepen attempt 0;;

(* Negate the generalized formula, Skolemize it with askolemize, split the  *)
(* result into DNF and run puremeson on the conjunction of each disjunct,   *)
(* returning the list of results.                                            *)
let meson fm =
  let refutand = askolemize(Not(generalize fm)) in
  let refute disjunct = puremeson(list_conj disjunct) in
  map refute (simpdnf refutand);;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let davis_putnam_example = meson
 <<exists x. exists y. forall z.
        (F(x,y) ==> (F(y,z) /\ F(z,z))) /\
        ((F(x,y) /\ G(x,y)) ==> (G(x,z) /\ G(z,z)))>>;;

let p38 = time meson
 <<(forall x.
     P(a) /\ (P(x) ==> (exists y. P(y) /\ R(x,y))) ==>
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))) <=>
   (forall x.
     (~P(a) \/ P(x) \/ (exists z w. P(z) /\ R(x,w) /\ R(w,z))) /\
     (~P(a) \/ ~(exists y. P(y) /\ R(x,y)) \/
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))))>>;;

let gilmore_9a = meson
 <<(forall x y. P(x,y) <=>
                forall u. exists v. F(x,u,v) /\ G(y,u) /\ ~H(x,y))
   ==> forall x. exists y. forall z.
             (P(y,x) ==> (P(x,z) ==> P(x,y))) /\
             (P(x,y) ==> (~P(x,z) ==> P(y,x) /\ P(z,y)))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* With repetition checking and divide-and-conquer search.                   *)
(* ------------------------------------------------------------------------- *)

(* Test whether fm1 and fm2 unify without changing env. The comparison is   *)
(* physical (==): it is true only when unify_literals hands back the very   *)
(* same env value. NOTE(review): this presumably relies on unify_literals   *)
(* returning its input untouched when no new binding is needed - confirm.   *)
(* A unification Failure yields false.                                      *)
let equal env fm1 fm2 =
  try
    let env' = unify_literals env (fm1,fm2) in
    env' == env
  with Failure _ -> false;;

(* Solve goals1 with budget n1, then goals2 with budget n2 plus whatever    *)
(* budget r1 the first half left over. A solution of the second half with   *)
(* leftover r2 is rejected with "pair" when n2 + r1 <= n3 + r2; otherwise   *)
(* the continuation is called on the final state.                           *)
let expand2 expfn goals1 n1 goals2 n2 n3 cont env k =
  let after_first (e1,r1,k1) =
    let budget = n2 + r1 in
    let after_second (e2,r2,k2) =
      if budget > n3 + r2 then cont(e2,r2,k2)
      else failwith "pair" in
    expfn goals2 after_second (e1,budget,k1) in
  expfn goals1 after_first (env,n1,k);;

(* [expand] solves a single goal g. It fails with "Too deep" on a negative  *)
(* budget and with "repetition" when g is already equal (under env) to an   *)
(* ancestor. Otherwise it tries ancestor unification first and then each    *)
(* rule, handing the rule hypotheses to [expands] with g as a new ancestor. *)
let rec expand rules ancestors g cont (env,n,k) =
  if n < 0 then failwith "Too deep"
  else if exists (equal env g) ancestors then failwith "repetition" else
  let reduce anc = cont (unify_literals env (g,negate anc),n,k) in
  let extend r =
    let (asm,c),k' = renamer k r in
    (* Unify before anything else, as the original argument order did.      *)
    let state = (unify_literals env (g,c),n - length asm,k') in
    expands rules (g::ancestors) asm cont state in
  try tryfind reduce ancestors
  with Failure _ -> tryfind extend rules

(* [expands] solves a list of goals. Zero or one goal is handled directly   *)
(* by [expand]. A longer list is cut in two by chop_list at half its        *)
(* length and the budget is split into n1 = n / 2 and n2 = n - n1; the      *)
(* halves are tried in order with n3 = -1 and, on Failure, in the opposite  *)
(* order with n3 = n1.                                                       *)
and expands rules ancestors gs cont (env,n,k) =
  if n < 0 then failwith "Too deep" else
  match gs with
    [] | [_] -> itlist (expand rules ancestors) gs cont (env,n,k)
  | _ ->
      let n1 = n / 2 in
      let n2 = n - n1 in
      let goals1,goals2 = chop_list (length gs / 2) gs in
      let expfn = expand2 (expands rules ancestors) in
      try expfn goals1 n1 goals2 n2 (-1) cont env k
      with Failure _ -> expfn goals2 n1 goals1 n2 n1 cont env k;;

(* Same driver as before, now picking up the refined expand: clausify fm,   *)
(* gather the contrapositives of all clauses as rules, and let deepen       *)
(* supply the size bound (starting from 0). The value returned by expand    *)
(* is discarded; the successful bound is the result.                        *)
let puremeson fm =
  let cls = clausal(specialize(pnf fm)) in
  let rules = itlist (fun cl acc -> contrapositives cl @ acc) cls [] in
  let attempt bound =
    ignore(expand rules [] False (fun x -> x) (undefined,bound,0));
    bound in
  deepen attempt 0;;

(* Refute the negated, generalized and Skolemized formula: each disjunct    *)
(* of its DNF is conjoined with list_conj and passed to puremeson, and the  *)
(* list of results is returned.                                             *)
let meson fm =
  let negated = Not(generalize fm) in
  let disjuncts = simpdnf(askolemize negated) in
  map (fun lits -> puremeson(list_conj lits)) disjuncts;;

(* ------------------------------------------------------------------------- *)
(* Test it on some of the Pelletiers.                                        *)
(* ------------------------------------------------------------------------- *)

(*

let prop_1 = time meson
 <<p ==> q <=> ~q ==> ~p>>;;

let prop_2 = time meson
 <<~ ~p <=> p>>;;

let prop_3 = time meson
 <<~(p ==> q) ==> q ==> p>>;;

let prop_4 = time meson
 <<~p ==> q <=> ~q ==> p>>;;

let prop_5 = time meson
 <<(p \/ q ==> p \/ r) ==> p \/ (q ==> r)>>;;

let prop_6 = time meson
 <<p \/ ~p>>;;

let prop_7 = time meson
 <<p \/ ~ ~ ~p>>;;

let prop_8 = time meson
 <<((p ==> q) ==> p) ==> p>>;;

let prop_9 = time meson
 <<(p \/ q) /\ (~p \/ q) /\ (p \/ ~q) ==> ~(~q \/ ~q)>>;;

let prop_10 = time meson
 <<(q ==> r) /\ (r ==> p /\ q) /\ (p ==> q /\ r) ==> (p <=> q)>>;;

let prop_11 = time meson
 <<p <=> p>>;;

let prop_12 = time meson
 <<((p <=> q) <=> r) <=> (p <=> (q <=> r))>>;;

let prop_13 = time meson
 <<p \/ q /\ r <=> (p \/ q) /\ (p \/ r)>>;;

let prop_14 = time meson
 <<(p <=> q) <=> (q \/ ~p) /\ (~q \/ p)>>;;

let prop_15 = time meson
 <<p ==> q <=> ~p \/ q>>;;

let prop_16 = time meson
 <<(p ==> q) \/ (q ==> p)>>;;

let prop_17 = time meson
 <<p /\ (q ==> r) ==> s <=> (~p \/ q \/ s) /\ (~p \/ ~r \/ s)>>;;

let p18 = time meson
 <<exists y. forall x. P(y) ==> P(x)>>;;

let p19 = time meson
 <<exists x. forall y z. (P(y) ==> Q(z)) ==> P(x) ==> Q(x)>>;;

let p20 = time meson
 <<(forall x y. exists z. forall w. P(x) /\ Q(y) ==> R(z) /\ U(w)) ==>
   (exists x y. P(x) /\ Q(y)) ==>
   (exists z. R(z))>>;;

let p21 = time meson
 <<(exists x. P ==> Q(x)) /\ (exists x. Q(x) ==> P)
   ==> (exists x. P <=> Q(x))>>;;

let p22 = time meson
 <<(forall x. P <=> Q(x)) ==> (P <=> (forall x. Q(x)))>>;;

let p23 = time meson
 <<(forall x. P \/ Q(x)) <=> P \/ (forall x. Q(x))>>;;

let p24 = time meson
 <<~(exists x. U(x) /\ Q(x)) /\
   (forall x. P(x) ==> Q(x) \/ R(x)) /\
   ~(exists x. P(x) ==> (exists x. Q(x))) /\
   (forall x. Q(x) /\ R(x) ==> U(x)) ==>
   (exists x. P(x) /\ R(x))>>;;

let p25 = time meson
 <<(exists x. P(x)) /\
   (forall x. U(x) ==> ~G(x) /\ R(x)) /\
   (forall x. P(x) ==> G(x) /\ U(x)) /\
   ((forall x. P(x) ==> Q(x)) \/ (exists x. Q(x) /\ P(x))) ==>
   (exists x. Q(x) /\ P(x))>>;;

let p26 = time meson
 <<((exists x. P(x)) <=> (exists x. Q(x))) /\
   (forall x y. P(x) /\ Q(y) ==> (R(x) <=> U(y))) ==>
   ((forall x. P(x) ==> R(x)) <=> (forall x. Q(x) ==> U(x)))>>;;

let p27 = time meson
 <<(exists x. P(x) /\ ~Q(x)) /\
   (forall x. P(x) ==> R(x)) /\
   (forall x. U(x) /\ V(x) ==> P(x)) /\
   (exists x. R(x) /\ ~Q(x)) ==>
   (forall x. U(x) ==> ~R(x)) ==>
   (forall x. U(x) ==> ~V(x))>>;;

let p28 = time meson
 <<(forall x. P(x) ==> (forall x. Q(x))) /\
   ((forall x. Q(x) \/ R(x)) ==> (exists x. Q(x) /\ R(x))) /\
   ((exists x. R(x)) ==> (forall x. L(x) ==> M(x))) ==>
   (forall x. P(x) /\ L(x) ==> M(x))>>;;

let p29 = time meson
 <<(exists x. P(x)) /\ (exists x. G(x)) ==>
   ((forall x. P(x) ==> H(x)) /\ (forall x. G(x) ==> J(x)) <=>
    (forall x y. P(x) /\ G(y) ==> H(x) /\ J(y)))>>;;

let p30 = time meson
 <<(forall x. P(x) \/ G(x) ==> ~H(x)) /\ (forall x. (G(x) ==> ~U(x)) ==>
     P(x) /\ H(x)) ==>
   (forall x. U(x))>>;;

let p31 = time meson
 <<~(exists x. P(x) /\ (G(x) \/ H(x))) /\ (exists x. Q(x) /\ P(x)) /\
   (forall x. ~H(x) ==> J(x)) ==>
   (exists x. Q(x) /\ J(x))>>;;

let p32 = time meson
 <<(forall x. P(x) /\ (G(x) \/ H(x)) ==> Q(x)) /\
   (forall x. Q(x) /\ H(x) ==> J(x)) /\
   (forall x. R(x) ==> H(x)) ==>
   (forall x. P(x) /\ R(x) ==> J(x))>>;;

let p33 = time meson
 <<(forall x. P(a) /\ (P(x) ==> P(b)) ==> P(c)) <=>
   (forall x. P(a) ==> P(x) \/ P(c)) /\ (P(a) ==> P(b) ==> P(c))>>;;

let p34 = time meson
 <<((exists x. forall y. P(x) <=> P(y)) <=>
   ((exists x. Q(x)) <=> (forall y. Q(y)))) <=>
   ((exists x. forall y. Q(x) <=> Q(y)) <=>
  ((exists x. P(x)) <=> (forall y. P(y))))>>;;

let p35 = time meson
 <<exists x y. P(x,y) ==> (forall x y. P(x,y))>>;;

let p36 = time meson
 <<(forall x. exists y. P(x,y)) /\
   (forall x. exists y. G(x,y)) /\
   (forall x y. P(x,y) \/ G(x,y)
   ==> (forall z. P(y,z) \/ G(y,z) ==> H(x,z)))
       ==> (forall x. exists y. H(x,y))>>;;

let p37 = time meson
 <<(forall z.
     exists w. forall x. exists y. (P(x,z) ==> P(y,w)) /\ P(y,z) /\
     (P(y,w) ==> (exists u. Q(u,w)))) /\
   (forall x z. ~P(x,z) ==> (exists y. Q(y,z))) /\
   ((exists x y. Q(x,y)) ==> (forall x. R(x,x))) ==>
   (forall x. exists y. R(x,y))>>;;

let p38 = time meson
 <<(forall x.
     P(a) /\ (P(x) ==> (exists y. P(y) /\ R(x,y))) ==>
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))) <=>
   (forall x.
     (~P(a) \/ P(x) \/ (exists z w. P(z) /\ R(x,w) /\ R(w,z))) /\
     (~P(a) \/ ~(exists y. P(y) /\ R(x,y)) \/
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))))>>;;

let p39 = time meson
 <<~(exists x. forall y. P(y,x) <=> ~P(y,y))>>;;

let p40 = time meson
 <<(exists y. forall x. P(x,y) <=> P(x,x))
  ==> ~(forall x. exists y. forall z. P(z,y) <=> ~P(z,x))>>;;

let p41 = time meson
 <<(forall z. exists y. forall x. P(x,y) <=> P(x,z) /\ ~P(x,x))
  ==> ~(exists z. forall x. P(x,z))>>;;

let p42 = time meson
 <<~(exists y. forall x. P(x,y) <=> ~(exists z. P(x,z) /\ P(z,x)))>>;;

let p43 = time meson
 <<(forall x y. Q(x,y) <=> forall z. P(z,x) <=> P(z,y))
   ==> forall x y. Q(x,y) <=> Q(y,x)>>;;

let p44 = time meson
 <<(forall x. P(x) ==> (exists y. G(y) /\ H(x,y)) /\
   (exists y. G(y) /\ ~H(x,y))) /\
   (exists x. J(x) /\ (forall y. G(y) ==> H(x,y))) ==>
   (exists x. J(x) /\ ~P(x))>>;;

let p45 = time meson
 <<(forall x.
     P(x) /\ (forall y. G(y) /\ H(x,y) ==> J(x,y)) ==>
       (forall y. G(y) /\ H(x,y) ==> R(y))) /\
   ~(exists y. L(y) /\ R(y)) /\
   (exists x. P(x) /\ (forall y. H(x,y) ==>
     L(y)) /\ (forall y. G(y) /\ H(x,y) ==> J(x,y))) ==>
   (exists x. P(x) /\ ~(exists y. G(y) /\ H(x,y)))>>;;

let p46 = time meson
 <<(forall x. P(x) /\ (forall y. P(y) /\ H(y,x) ==> G(y)) ==> G(x)) /\
   ((exists x. P(x) /\ ~G(x)) ==>
    (exists x. P(x) /\ ~G(x) /\
               (forall y. P(y) /\ ~G(y) ==> J(x,y)))) /\
   (forall x y. P(x) /\ P(y) /\ H(x,y) ==> ~J(y,x)) ==>
   (forall x. P(x) ==> G(x))>>;;

let p55 = time meson
 <<lives(agatha) /\ lives(butler) /\ lives(charles) /\
   (killed(agatha,agatha) \/ killed(butler,agatha) \/
    killed(charles,agatha)) /\
   (forall x y. killed(x,y) ==> hates(x,y) /\ ~richer(x,y)) /\
   (forall x. hates(agatha,x) ==> ~hates(charles,x)) /\
   (hates(agatha,agatha) /\ hates(agatha,charles)) /\
   (forall x. lives(x) /\ ~richer(x,agatha) ==> hates(butler,x)) /\
   (forall x. hates(agatha,x) ==> hates(butler,x)) /\
   (forall x. ~hates(x,agatha) \/ ~hates(x,butler) \/ ~hates(x,charles))
   ==> killed(agatha,agatha) /\
       ~killed(butler,agatha) /\
       ~killed(charles,agatha)>>;;

let p57 = time meson
 <<P(f((a),b),f(b,c)) /\
  P(f(b,c),f(a,c)) /\
  (forall (x) y z. P(x,y) /\ P(y,z) ==> P(x,z))
  ==> P(f(a,b),f(a,c))>>;;

(* ------------------------------------------------------------------------- *)
(* Translation of Gilmore procedure using separate definitions.              *)
(* ------------------------------------------------------------------------- *)

let gilmore_9a = time meson
 <<(forall x y. P(x,y) <=>
                forall u. exists v. F(x,u,v) /\ G(y,u) /\ ~H(x,y))
   ==> forall x. exists y. forall z.
             (P(y,x) ==> (P(x,z) ==> P(x,y))) /\
             (P(x,y) ==> (~P(x,z) ==> P(y,x) /\ P(z,y)))>>;;

(* ------------------------------------------------------------------------- *)
(* Example from Davis-Putnam papers where Gilmore procedure is poor.         *)
(* ------------------------------------------------------------------------- *)

let davis_putnam_example = time meson
 <<exists x. exists y. forall z.
        (F(x,y) ==> (F(y,z) /\ F(z,z))) /\
        ((F(x,y) /\ G(x,y)) ==> (G(x,z) /\ G(z,z)))>>;;

(* ------------------------------------------------------------------------- *)
(* The "connections make things worse" example once again.                   *)
(* ------------------------------------------------------------------------- *)

let th = meson
 <<~p /\ (p \/ q) /\ (r \/ s) /\ (~q \/ t \/ u) /\
   (~r \/ ~t) /\ (~r \/ ~u) /\ (~q \/ v \/ w) /\
   (~s \/ ~v) /\ (~s \/ ~w) ==> false>>;;

*)
(* ========================================================================= *)
(* Illustration of Skolemizing a set of formulas                             *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Prefix every function symbol in a term with "old_", recursively through  *)
(* the arguments; any other term is returned unchanged.                     *)
let rec rename_term =
  function
    Fn(f,args) ->
      let args' = map rename_term args in
      Fn("old_"^f,args')
  | tm -> tm;;

(* Apply rename_term to every argument of every atom in fm; predicate       *)
(* names themselves are left alone.                                         *)
let rename_form fm =
  let rename_atom (R(p,args)) = Atom(R(p,map rename_term args)) in
  onatoms rename_atom fm;;

(* Skolemize a list of formulas from left to right, threading the second    *)
(* component returned by skolem (corr) from each formula to the next. Each  *)
(* formula is first passed through rename_form. Returns the processed       *)
(* formulas in their original order together with the final corr.           *)
let skolems fms corr =
  let step (done_rev,corr0) fm =
    let fm',corr1 = skolem (rename_form fm) corr0 in
    (fm'::done_rev,corr1) in
  let done_rev,corr' = List.fold_left step ([],corr) fms in
  (List.rev done_rev,corr');;

(* Skolemize a list of formulas starting from an empty corr, keeping only   *)
(* the resulting formulas.                                                  *)
let skolemizes fms =
  let fms',_ = skolems fms [] in
  fms';;

(*
skolemizes [<<exists x y. x + y = 2>>;
            <<forall x. exists y. x + 1 = y>>];;
*)
(* ========================================================================= *)
(* First order logic with equality.                                          *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Build the equation s = t as an atom of the binary predicate "=".         *)
let mk_eq s t =
  let sides = [s;t] in
  Atom(R("=",sides));;

(* Split an equation into its two sides. Fails with "dest_eq: not an        *)
(* equation" for anything other than an "=" atom with two arguments.        *)
let dest_eq fm =
  match fm with
    Atom(R("=",[lhs;rhs])) -> (lhs,rhs)
  | _ -> failwith "dest_eq: not an equation";;

(* ------------------------------------------------------------------------- *)
(* The set of predicates in a formula.                                       *)
(* ------------------------------------------------------------------------- *)

(* Collect, via atom_union, a (name,arity) pair for each atom of fm, the    *)
(* arity being the length of the atom's argument list.                      *)
let predicates fm =
  let signature (R(p,args)) = [(p,length args)] in
  atom_union signature fm;;

(* ------------------------------------------------------------------------- *)
(* Code to generate equality axioms for functions.                           *)
(* ------------------------------------------------------------------------- *)

(* Congruence axiom for the n-ary function f, returned as a list holding    *)
(* one formula (or no formula when n = 0):                                  *)
(*   forall x1..xn y1..yn. x1 = y1 /\ ... /\ xn = yn                        *)
(*                          ==> f(x1,..,xn) = f(y1,..,yn)                   *)
let function_congruence (f,n) =
  if n = 0 then [] else
  let indices = 1 -- n in
  let names prefix = map (fun i -> prefix^(string_of_int i)) indices in
  let xs = names "x" and ys = names "y" in
  let vars vs = map (fun v -> Var v) vs in
  let args_x = vars xs and args_y = vars ys in
  let ant = end_itlist (fun a b -> And(a,b)) (map2 mk_eq args_x args_y)
  and con = mk_eq (Fn(f,args_x)) (Fn(f,args_y)) in
  [itlist (fun x body -> Forall(x,body)) (xs @ ys) (Imp(ant,con))];;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
function_congruence ("f",3);;

function_congruence ("+",2);;
*)

(* ------------------------------------------------------------------------- *)
(* And for predicates.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Congruence axiom for the n-ary predicate p, returned as a list holding   *)
(* one formula (or no formula when n = 0):                                  *)
(*   forall x1..xn y1..yn. x1 = y1 /\ ... /\ xn = yn                        *)
(*                          ==> p(x1,..,xn) ==> p(y1,..,yn)                 *)
let predicate_congruence (p,n) =
  if n = 0 then [] else
  let numbered prefix = map (fun i -> prefix^(string_of_int i)) (1 -- n) in
  let xs = numbered "x" and ys = numbered "y" in
  let as_vars vs = map (fun v -> Var v) vs in
  let args_x = as_vars xs and args_y = as_vars ys in
  let equations = map2 (fun s t -> Atom(R("=",[s;t]))) args_x args_y in
  let ant = end_itlist (fun a b -> And(a,b)) equations
  and con = Imp(Atom(R(p,args_x)),Atom(R(p,args_y))) in
  let close v body = Forall(v,body) in
  [itlist close (xs @ ys) (Imp(ant,con))];;

(* ------------------------------------------------------------------------- *)
(* Hence implement logic with equality just by adding equality "axioms".     *)
(* ------------------------------------------------------------------------- *)

(* The two basic axioms for "=": reflexivity, and the Euclidean property    *)
(* (two things equal to the same thing are equal to each other), which      *)
(* together with reflexivity gives symmetry and transitivity.               *)
let equivalence_axioms =
  let reflexivity = <<forall x. x = x>>
  and euclidean = <<forall x y z. x = y /\ x = z ==> y = z>> in
  setify [reflexivity; euclidean];;

(* If fm does not use the binary predicate "=", return it unchanged.        *)
(* Otherwise return (axioms ==> fm), where the axioms are the conjunction   *)
(* of equivalence_axioms, a congruence axiom for each predicate of fm       *)
(* other than "=", and a congruence axiom for each function of fm.          *)
let equalitize fm =
  let eq_sig = ("=",2) in
  let allpreds = predicates fm in
  if mem eq_sig allpreds then
    let preds = subtract allpreds [eq_sig]
    and funcs = functions fm in
    let with_preds =
      itlist (fun pr acc -> union (predicate_congruence pr) acc)
             preds equivalence_axioms in
    let axioms =
      itlist (fun fn acc -> union (function_congruence fn) acc)
             funcs with_preds in
    Imp(end_itlist (fun p q -> And(p,q)) axioms,fm)
  else fm;;

(* ------------------------------------------------------------------------- *)
(* A simple example (see EWD1266a and the application to Morley's theorem).  *)
(* ------------------------------------------------------------------------- *)

(*

let ewd = equalitize
 <<(forall x. f(x) ==> g(x)) /\
   (exists x. f(x)) /\
   (forall x y. g(x) /\ g(y) ==> x = y)
   ==> forall y. g(y) ==> f(y)>>;;

meson ewd;;

resolution ewd;;

splittab ewd;;

(* ------------------------------------------------------------------------- *)
(* Wishnu Prasetya's example (even nicer with an "exists unique" primitive). *)
(* ------------------------------------------------------------------------- *)

let wishnu = equalitize
 <<(exists x. x = f(g(x)) /\ forall x'. x' = f(g(x')) ==> x = x') <=>
   (exists y. y = g(f(y)) /\ forall y'. y' = g(f(y')) ==> y = y')>>;;

time meson wishnu;;

(* ------------------------------------------------------------------------- *)
(* An incestuous example used to establish completeness characterization.    *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall M p. sentence(p) ==> holds(M,p) \/ holds(M,not(p))) /\
   (forall M p. ~(holds(M,p) /\ holds(M,not(p))))
   ==> ((forall p. sentence(p)
                   ==> (forall M. models(M,S) ==> holds(M,p)) \/
                       (forall M. models(M,S) ==> holds(M,not(p)))) <=>
        (forall M M'. models(M,S) /\ models(M',S)
                      ==> forall p. sentence(p)
                                    ==> (holds(M,p) <=> holds(M',p))))>>;;

(* ------------------------------------------------------------------------- *)
(* Showing congruence closure.                                               *)
(* ------------------------------------------------------------------------- *)

let fm = equalitize
 <<forall c. f(f(f(f(f(c))))) = c /\ f(f(f(c))) = c ==> f(c) = c>>;;

time meson fm;;

*)
(* ========================================================================= *)
(* Simple congruence closure.                                                *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Test whether subterms are congruent under an equivalence.                 *)
(* ------------------------------------------------------------------------- *)

(* [congruent eqv (s,t)] holds when s and t are applications of the same    *)
(* function symbol whose corresponding arguments have the same canonical    *)
(* representative under eqv. Any pair involving a non-application gives     *)
(* false. The conjunction uses "&&" in place of the deprecated "&".         *)
let congruent eqv =
  function (Fn(f,a1),Fn(g,a2)) ->
             f = g &&
             forall2 (fun s t -> canonize eqv s = canonize eqv t) a1 a2
         | _ -> false;;

(* ------------------------------------------------------------------------- *)
(* Merging of terms, with congruence closure.                                *)
(* ------------------------------------------------------------------------- *)

(* [emerge (s,t) (eqv,pfn)] merges the classes of s and t. Nothing changes  *)
(* if they already share a canonical representative. Otherwise the lists    *)
(* that pfn records for the two old representatives (sp and tp) are         *)
(* united and stored under the representative of the merged class, and      *)
(* every pair (u,v) drawn from sp and tp that is congruent under the        *)
(* current equivalence is merged recursively.                               *)
let rec emerge (s,t) (eqv,pfn) =
  let s' = canonize eqv s and t' = canonize eqv t in
  if s' = t' then (eqv,pfn) else
  let sp = tryapplyl pfn s' and tp = tryapplyl pfn t' in
  let eqv' = equate (s,t) eqv in
  let rep = canonize eqv' s' in
  let pfn' = (rep |-> union sp tp) pfn in
  let propagate pair ((eqv0,_) as state) =
    if congruent eqv0 pair then emerge pair state else state in
  itlist propagate (allpairs (fun u v -> (u,v)) sp tp) (eqv',pfn');;

(* ------------------------------------------------------------------------- *)
(* Useful auxiliary functions.                                               *)
(* ------------------------------------------------------------------------- *)

(* Put tm, followed by all subterms of its arguments, in front of acc.      *)
(* Duplicates are not removed here.                                         *)
let rec subterms tm acc =
  let below =
    match tm with
      Var _ -> acc
    | Fn(_,args) -> itlist subterms args acc in
  tm::below;;

(* The immediate arguments of a function application, passed through        *)
(* setify; the empty list for any other term.                               *)
let successors tm =
  match tm with
    Fn(_,args) -> setify args
  | _ -> [];;

(* ------------------------------------------------------------------------- *)
(* Satisfiability of conjunction of ground equations and inequations.        *)
(* ------------------------------------------------------------------------- *)

(* Decide satisfiability of a list of equations and negated equations.      *)
(* All subterms of both sides are collected; pfn is built so that each      *)
(* term maps to the terms having it as an immediate argument; every         *)
(* positive equation is merged with emerge, starting from unequal; the      *)
(* result is true when no negated equation has both sides in one class.     *)
let ccsatisfiable fms =
  let pos,neg = partition positive fms in
  let eqps = map dest_eq pos
  and eqns = map (fun lit -> dest_eq(negate lit)) neg in
  let sides eqs = map fst eqs @ map snd eqs in
  let tms = setify (itlist subterms (sides eqps @ sides eqns) []) in
  let add_pred parent child f =
    (child |-> insert parent (tryapplyl f child)) f in
  let pfn =
    itlist (fun tm -> itlist (add_pred tm) (successors tm)) tms undefined in
  let eqv,_ = itlist emerge eqps (unequal,pfn) in
  let distinct (l,r) = canonize eqv l <> canonize eqv r in
  forall distinct eqns;;

(* ------------------------------------------------------------------------- *)
(* Convert uninterpreted predicates into functions.                          *)
(* ------------------------------------------------------------------------- *)

(* Replace every atom other than a two-argument "=" by an equation          *)
(* f(args) = tr, where f is a function name of the form "P_i" chosen per    *)
(* (predicate,arity) pair and tr is a constant using the last generated     *)
(* name. One more name than there are predicates is generated, numbered     *)
(* from one past the largest index that max_varindex reports over the       *)
(* function names of fm.                                                     *)
(* NOTE(review): the index scan is given the prefix "P" while the names     *)
(* generated here start with "P_" - confirm max_varindex treats these as    *)
(* intended.                                                                 *)
let atomize fm =
  let preds = predicates fm and funs = functions fm in
  let n = Int 1 +/ itlist (max_varindex "P" ** fst) funs (Int 0) in
  let preds' = map (fun i -> "P_"^string_of_num i)
                   (n---(n+/Int(length preds))) in
  let alist = zip preds (butlast preds') and tr = Fn(last preds',[]) in
  let equalize(R(p,args) as at) =
    (* "&&" replaces the deprecated "&" conjunction.                        *)
    if p = "=" && length args = 2 then Atom at else
    Atom(R("=",[Fn(assoc (p,length args) alist,args); tr])) in
  onatoms equalize fm;;

(* ------------------------------------------------------------------------- *)
(* Validity checking a universal formula (this theory is trivially convex).  *)
(* ------------------------------------------------------------------------- *)

(* fm is reported valid when, after atomize, generalization, negation and   *)
(* Skolemization, no disjunct of the DNF is satisfiable according to        *)
(* ccsatisfiable.                                                            *)
let ccvalid fm =
  let negated = Not(generalize(atomize fm)) in
  let disjuncts = simpdnf(askolemize negated) in
  not (exists ccsatisfiable disjuncts);;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let fm =
 <<f(f(f(f(f(c))))) = c /\ f(f(f(c))) = c
   ==> f(c) = c \/ f(g(c)) = g(f(c))>> in
ccvalid fm;;

let fm = <<f(f(f(f(c)))) = c /\ f(f(c)) = c ==> f(c) = c>> in
ccvalid fm;;

let fm =
 <<f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(c))))))))))))))) = c /\
  f(f(f(f(c)))) = c
  ==> f(c) = c>> in
ccvalid fm;;

let fm =
 <<f(f(f(f(f(c))))) = c /\ f(f(f(c))) = c /\ (P(c) <=> ~Q(f(c)))
   ==> P(f(c)) \/ Q(c)>> in
ccvalid fm;;

*)
(* ========================================================================= *)
(* Rewriting.                                                                *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Matching of terms.                                                        *)
(* ------------------------------------------------------------------------- *)

(* [tmatch (vtm,ctm) env] extends env so that instantiating the pattern     *)
(* vtm yields ctm. A pattern variable not yet in env is bound to the        *)
(* corresponding term; one already bound must be bound to an equal term.    *)
(* Two applications match when they have the same function name and the     *)
(* same number of arguments and their arguments match pairwise. Every       *)
(* mismatch fails with "tmatch".                                            *)
let rec tmatch (vtm,ctm) env =
  match (vtm,ctm) with
    (Var x,t) ->
        if not (defined env x) then (x |-> t) env
        else if apply env x = t then env else failwith "tmatch"
  | Fn(f,fargs),Fn(g,gargs) ->
        (* Functions are identified by name and arity elsewhere in this     *)
        (* file, so check the arity here instead of leaving a length        *)
        (* mismatch to zip.                                                  *)
        if f = g && length fargs = length gargs
        then itlist tmatch (zip fargs gargs) env
        else failwith "tmatch"
  | _ -> failwith "tmatch";;

(* Match the pattern vtm against ctm, starting from the undefined           *)
(* (empty) environment.                                                      *)
let term_match vtm ctm =
  let start = undefined in
  tmatch (vtm,ctm) start;;

(* ------------------------------------------------------------------------- *)
(* Rewriting with a single equation.                                         *)
(* ------------------------------------------------------------------------- *)

(* Rewrite the term t at its root using the equation eq left-to-right: the   *)
(* left side is matched against t and the resulting instantiation is applied *)
(* to the right side. Raises Failure "rewrite1" if eq is not a two-argument  *)
(* "=" atom; a failed match propagates from term_match.                      *)
let rewrite1 eq t =
  match eq with
    Atom(R("=",[lhs;rhs])) ->
        let inst = term_match lhs t in
        termsubst inst rhs
  | _ -> failwith "rewrite1";;

(* ------------------------------------------------------------------------- *)
(* Rewriting with first in a list of equations.                              *)
(* ------------------------------------------------------------------------- *)

(* Rewrite tm at its root with an equation from eqs, tried through `tryfind` *)
(* (presumably first success wins, failure if none applies).                 *)
let rewrite eqs tm =
  let attempt eq = rewrite1 eq tm in
  tryfind attempt eqs;;

(* ------------------------------------------------------------------------- *)
(* Applying a term transformation at depth.                                  *)
(* ------------------------------------------------------------------------- *)

(* Apply the term transformation fn exhaustively. A Failure from fn means    *)
(* "not applicable here".                                                    *)
(*   1. While fn applies at the root, apply it and start again.              *)
(*   2. Otherwise normalise each argument; if that changed the term, start   *)
(*      again on the rebuilt term, else return it.                           *)
(* The handler also covers the recursive call made on the result of fn.      *)
(* There is no bound on the number of steps, so a transformation that can    *)
(* always be applied makes this loop forever.                                *)
let rec depth fn tm =
  try depth fn (fn tm) with Failure _ ->
    (match tm with
       Var _ -> tm
     | Fn(f,args) ->
         let args' = map (depth fn) args in
         let tm' = Fn(f,args') in
         if tm' <> tm then depth fn tm' else tm');;

(* ------------------------------------------------------------------------- *)
(* Example: 3 * 2 + 4 in successor notation.                                 *)
(* ------------------------------------------------------------------------- *)

(*
let eqs =
 [<<0 + x = x>>; <<S(x) + y = S(x + y)>>; <<x + S(y) = S(x + y)>>;
  <<0 * x = 0>>; <<S(x) * y = y + x * y>>];;

depth (rewrite eqs) <<|S(S(S(0))) * S(S(0)) + S(S(S(S(0))))|>>;;

(* ------------------------------------------------------------------------- *)
(* Combinatory logic.                                                        *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<((S * f) * g) * x = (f * x) * (g * x)>>;
  <<(K * x) * y = x>>];;

depth (rewrite eqs) <<|((S * K) * K) * x|>>;;

(* ------------------------------------------------------------------------- *)
(* The 3x + 1 problem (Collatz conjecture).                                  *)
(* ------------------------------------------------------------------------- *)

let eqs =
  [<<1 = S(0)>>;
   <<2 = S(1)>>;
   <<3 = S(2)>>;
   <<0 + x = x>>;
   <<S(x) + y = S(x + y)>>;
   <<0 * y = 0>>;
   <<S(x) * y = y + x * y>>;
   <<run(S(S(x)),y) = run(x,S(y))>>;
   <<run(S(0),S(y)) = run(3 * (2 * y + 1) + 1,0)>>;
   <<run(0,S(y)) = run(S(y),0)>>;
   <<run(S(0),0) = one>>;
   <<run(0,0) = zero>>];;

(* ------------------------------------------------------------------------- *)
(* The calamitously inefficient example.                                     *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<S(x) + (y + z) = x + (S(S(y)) + z)>>;
  <<S(u) + (v + (w + x)) = u + (w + (v + x))>>];;

depth (rewrite eqs) <<|S(x) + (y + z)|>>;;

depth (rewrite eqs) <<|S(a) + b + c + d + e + f + g + h|>>;;

depth (rewrite eqs) <<|S(a) + b + c + d|>>;;

depth (rewrite eqs) <<|S(a) + b + c + d + e|>>;;

depth (rewrite eqs) <<|S(a) + b + c + d + e + f|>>;;

depth (rewrite eqs) <<|S(a) + S(b) + S(c) + S(d) + S(e) + S(f)|>>;;

let eqs =
 [<<S(u) + (v + (w + x)) = u + (w + (v + x))>>;
  <<S(x) + (y + z) = x + (S(S(y)) + z)>>];;

depth (rewrite eqs) <<|S(a) + S(b) + S(c) + S(d) + S(e) + S(f)|>>;;

depth (rewrite eqs) <<|S(a) + b + c + d + e + f|>>;;

depth (rewrite eqs) <<|(S(a) + S(b)) + (S(c) + S(d)) + (S(e) + S(f))|>>;;

depth (rewrite eqs)
  <<|S(0) + (S(a) + S(b)) + (S(c) + S(d)) + (S(e) + S(f))|>>;;

*)
(* ========================================================================= *)
(* Term orderings.                                                           *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Size of a term: a variable counts 1; an application counts 1 plus the     *)
(* sizes of all its arguments.                                               *)
let rec termsize = function
    Var _ -> 1
  | Fn(_,args) ->
        let add_size arg total = termsize arg + total in
        itlist add_size args 1;;

(* ------------------------------------------------------------------------- *)
(* Term size is not stable under substitution, so fails as a rewrite order.  *)
(* ------------------------------------------------------------------------- *)

(*
let s = <<|f(x,x,x)|>> and t = <<|g(x,y)|>>;;

termsize s > termsize t;;

let i = ("y" := <<|f(x,x,x)|>>);;

termsize (termsubst i s) > termsize (termsubst i t);;
*)

(* ------------------------------------------------------------------------- *)
(* However we can do better with the following.                              *)
(* ------------------------------------------------------------------------- *)

(* Count how many times the variable named x occurs in the term tm.          *)
let rec occurrences x tm =
  match tm with
    Var y when x = y -> 1
  | Var _ -> 0
  | Fn(_,args) ->
        let count arg total = occurrences x arg + total in
        itlist count args 0;;

(* ------------------------------------------------------------------------- *)
(* Lexicographic path order.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Lexicographic extension of the relation ord to lists.                     *)
(* Scanning both lists from the left:                                        *)
(*   - if ord holds between the current heads, the result is true exactly    *)
(*     when the remaining tails have the same length;                        *)
(*   - otherwise the heads must be equal and the tails are compared;         *)
(*   - if either list runs out first, the result is false.                   *)
(* Uses && rather than the deprecated & operator, which newer compilers no   *)
(* longer provide; the two have the same meaning and precedence.             *)
let rec lexord ord l1 l2 =
  match (l1,l2) with
    (h1::t1,h2::t2) -> if ord h1 h2 then length t1 = length t2
                       else h1 = h2 && lexord ord t1 t2
  | _ -> false;;

(* Lexicographic path order on terms, parameterised by w, a relation on      *)
(* (symbol, arity) pairs.  lpo_gt w s t holds when:                          *)
(*   - t is a variable different from s that occurs among `fvt s`; or        *)
(*   - both are applications and either                                      *)
(*       some argument of s is lpo_ge t, or                                  *)
(*       s is lpo_gt every argument of t and, in addition, either the heads  *)
(*       are equal with the argument lists related by lexord, or w ranks the *)
(*       head of s (with its arity) above the head of t (with its arity).    *)
(* Every other case (in particular a variable on the left against an         *)
(* application) is false.  lpo_ge is the reflexive closure.                  *)
(* Uses && and || instead of the deprecated & and or operators, which newer  *)
(* compilers no longer provide; grouping is unchanged.                       *)
let rec lpo_gt w s t =
  match (s,t) with
    (_,Var x) ->
        not(s = t) && mem x (fvt s)
  | (Fn(f,fargs),Fn(g,gargs)) ->
        exists (fun si -> lpo_ge w si t) fargs ||
        forall (lpo_gt w s) gargs &&
        (f = g && lexord (lpo_gt w) fargs gargs ||
         w (f,length fargs) (g,length gargs))
  | _ -> false

and lpo_ge w s t = (s = t) || lpo_gt w s t;;

(* ------------------------------------------------------------------------- *)
(* More convenient way of specifying weightings.                             *)
(* ------------------------------------------------------------------------- *)

(* Build a relation on (symbol, arity) pairs from the list lis: (f,n) ranks  *)
(* above (g,m) when `index f lis` exceeds `index g lis`, or when the two     *)
(* indices coincide and n > m.                                               *)
(* NOTE(review): behaviour for a symbol absent from lis is whatever `index`  *)
(* does in that case -- check its definition.                                *)
(* Uses || and && instead of the deprecated or and & operators.              *)
let weight lis (f,n) (g,m) =
  let i = index f lis and j = index g lis in
  i > j || i = j && n > m;;
(* ========================================================================= *)
(* Knuth-Bendix completion.                                                  *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Rename the free variables of two formulas so that they share none.        *)
(* Fresh variables x0, x1, ... are generated, one per free variable of fm1   *)
(* and fm2 together; `chop_list` hands the first (length of fv fm1) of them  *)
(* to fm1 and the rest to fm2.  Each formula is then instantiated with the   *)
(* map sending its free variables to its share of the fresh ones.            *)
let renamepair (fm1,fm2) =
  let fvs1 = fv fm1
  and fvs2 = fv fm2 in
  let n1 = length fvs1 and n2 = length fvs2 in
  let fresh n = Var("x"^string_of_int n) in
  let nms1,nms2 = chop_list n1 (map fresh (0 -- (n1 + n2 - 1))) in
  let rename fvs nms fm =
    formsubst (itlist2 (fun x t -> x |-> t) fvs nms undefined) fm in
  rename fvs1 nms1 fm1,
  rename fvs2 nms2 fm2;;

(* ------------------------------------------------------------------------- *)
(* Rewrite (using unification) with l = r inside tm to give a critical pair. *)
(* ------------------------------------------------------------------------- *)

(* For each element h of lis, call fn h k, where the continuation k i h2     *)
(* rebuilds the whole list with h replaced by h2 and hands it, together      *)
(* with i, to rfn.  The lists returned by fn for successive elements are     *)
(* concatenated in list order, and acc is appended at the very end.          *)
let rec listcases fn rfn lis acc =
  match lis with
    [] -> acc
  | h::t ->
        let rebuild_head i h' = rfn i (h'::t)
        and rebuild_tail i t' = rfn i (h::t') in
        fn h rebuild_head @ listcases fn rebuild_tail t acc;;

(* Collect the overlaps of l with tm and with the non-variable subterms of   *)
(* tm.  Wherever `fullunify` succeeds on l and the subterm, giving i, the    *)
(* contribution is rfn i applied to tm with that subterm replaced by r.      *)
(* A Failure while handling the root position simply contributes nothing.    *)
(* Results from the arguments come first, the root overlap last.  Variables  *)
(* produce no overlaps.                                                      *)
let rec overlaps (l,r) tm rfn =
  match tm with
    Var _ -> []
  | Fn(f,args) ->
        let at_root =
          try [rfn (fullunify [l,tm]) r] with Failure _ -> [] in
        let rebuild i args' = rfn i (Fn(f,args')) in
        listcases (overlaps (l,r)) rebuild args at_root;;

(* ------------------------------------------------------------------------- *)
(* Generate all critical pairs between two equations.                        *)
(* ------------------------------------------------------------------------- *)

(* Critical pairs obtained by overlapping the left side l1 of the first      *)
(* equation into the left side l2 of the second.  Each overlap with          *)
(* instantiation i yields the equation (l2 with the overlapped subterm       *)
(* replaced by r1) = r2, instantiated by i.  Both arguments must be          *)
(* two-argument "=" atoms; anything else is not covered by the patterns.     *)
let crit1 (Atom(R("=",[l1;r1]))) (Atom(R("=",[l2;r2]))) =
  let mk_pair i t = formsubst i (Atom(R("=",[t;r2]))) in
  overlaps (l1,r1) l2 mk_pair;;

(* All critical pairs between two equations.  The equations are first        *)
(* renamed apart; overlaps are then collected in both directions and merged  *)
(* with `union`, except that when the two inputs are the same equation one   *)
(* direction is enough.                                                      *)
let critical_pairs fma fmb =
  let fm1,fm2 = renamepair (fma,fmb) in
  let forward = crit1 fm1 fm2 in
  if fma = fmb then forward
  else union forward (crit1 fm2 fm1);;

(* ------------------------------------------------------------------------- *)
(* Simple example.                                                           *)
(* ------------------------------------------------------------------------- *)

(* Sample equation for the commented-out critical_pairs call that follows.   *)
(* This is a live top-level binding of the name eq.                          *)
let eq = <<f(f(x)) = g(x)>>;;
(*
critical_pairs eq eq;;
*)

(* ------------------------------------------------------------------------- *)
(* Orienting an equation.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Given an ordering ord and rewrite rules eqs, return a function on an      *)
(* equation s = t that normalises both sides with depth (rewrite eqs) and    *)
(* returns them as a pair (larger, smaller) according to ord.  The first     *)
(* direction tried is (s, t).  Raises Failure when ord holds in neither      *)
(* direction.                                                                *)
let normalize_and_orient ord eqs =
  fun (Atom(R("=",[s;t]))) ->
    let normalize = depth(rewrite eqs) in
    let s' = normalize s and t' = normalize t in
    if ord s' t' then (s',t')
    else if ord t' s' then (t',s')
    else failwith "Can't orient equation";;

(* ------------------------------------------------------------------------- *)
(* Status report so the user doesn't get too bored.                          *)
(* ------------------------------------------------------------------------- *)

(* Print one progress line giving the number of equations, pending critical  *)
(* pairs and deferred equations.  Nothing is printed when the equation list  *)
(* equals eqs0 and the number of pending pairs is not a multiple of 1000.    *)
(* Uses && instead of the deprecated & operator.                             *)
let status(eqs,def,crs) eqs0 =
  if eqs = eqs0 && (length crs) mod 1000 <> 0 then () else
  (print_string(string_of_int(length eqs)^" equations and "^
                string_of_int(length crs)^" pending critical pairs + "^
                string_of_int(length def)^" deferred");
   print_newline());;

(* ------------------------------------------------------------------------- *)
(* Completion main loop (deferring non-orientable equations).                *)
(* ------------------------------------------------------------------------- *)

(* Completion loop over the state (eqs, def, crits): current rules,          *)
(* deferred equations, and pending critical pairs.                           *)
(*   - With a pending pair eq: normalise and orient it.  If both sides       *)
(*     coincide it is dropped.  Otherwise the oriented equation is added to  *)
(*     the rules and its critical pairs against every rule (itself           *)
(*     included) are queued after the remaining ones.  A Failure anywhere in *)
(*     that step moves eq to the deferred list instead.  Progress is then    *)
(*     reported through status and the loop continues.                       *)
(*   - With no pending pairs: return eqs if nothing is deferred; otherwise   *)
(*     pick a deferred equation that can now be oriented (`find` decides     *)
(*     what happens when there is none) and process it as the only pair.     *)
let rec complete ord (eqs,def,crits) =
  match crits with
    [] ->
        if def = [] then eqs else
        let orientable = can (normalize_and_orient ord eqs) in
        let e = find orientable def in
        complete ord (eqs,subtract def [e],[e])
  | eq::ocrits ->
        let trip =
          try
            let (s',t') = normalize_and_orient ord eqs eq in
            if s' = t' then (eqs,def,ocrits) else
            let eq' = Atom(R("=",[s';t'])) in
            let eqs' = eq'::eqs in
            let new_crits =
              itlist (fun e acc -> critical_pairs eq' e @ acc) eqs' [] in
            (eqs',def,ocrits @ new_crits)
          with Failure _ -> (eqs,eq::def,ocrits) in
        status trip eqs;
        complete ord trip;;

(* ------------------------------------------------------------------------- *)
(* A simple "manual" example, before considering packaging and refinements.  *)
(* ------------------------------------------------------------------------- *)

(* Example equations over 1, i and *: 1 * x = x, i(x) * x = 1, and           *)
(* associativity of *.  These are live top-level bindings of eqs and ord.    *)
let eqs =
 [<<1 * x = x>>; <<i(x) * x = 1>>; <<(x * y) * z = x * y * z>>];;

(* Ordering for the example: lpo_ge over the symbol list "1", "*", "i".      *)
let ord = lpo_ge (weight ["1"; "*"; "i"]);;

(*
let eqs' = complete ord
  (eqs,[],unions(allpairs critical_pairs eqs eqs));;

let tm = <<|i(x * i(x)) * (i(i((y * z) * u) * y) * i(u))|>>;;

depth(rewrite eqs') tm;;
*)

(* ------------------------------------------------------------------------- *)
(* Show that we get a significant difference just from changing order.       *)
(* ------------------------------------------------------------------------- *)

(*
let eqs =
 [<<(x * y) * z = x * y * z>>; <<1 * x = x>>; <<i(x) * x = 1>>];;

let eqs'' = complete ord
  (eqs,[],unions(allpairs critical_pairs eqs eqs));;
*)

(* ------------------------------------------------------------------------- *)
(* Interreduction.                                                           *)
(* ------------------------------------------------------------------------- *)

(* Simplify a list of equations against each other.  dun accumulates the     *)
(* equations already kept (most recent first); eqs are those still to do.    *)
(* For each equation l = r, rewriting uses the kept equations plus the       *)
(* remaining ones:                                                           *)
(*   - if l itself can be rewritten, the equation is discarded;              *)
(*   - otherwise it is kept with r replaced by its normal form.              *)
(* The result is the kept equations in original order.  A list element that  *)
(* is not a two-argument "=" atom raises Failure "non-equational input".     *)
(* The ord argument is accepted but not used in the body.                    *)
let rec interreduce ord dun eqs =
  match eqs with
    [] -> rev dun
  | (Atom(R("=",[l;r])))::oeqs ->
        let rewr_fn = depth(rewrite (dun @ oeqs)) in
        if rewr_fn l = l then
          let reduced = Atom(R("=",[l;rewr_fn r])) in
          interreduce ord (reduced::dun) oeqs
        else interreduce ord dun oeqs
  | _ -> failwith "non-equational input";;

(* ------------------------------------------------------------------------- *)
(* This does indeed help a lot.                                              *)
(* ------------------------------------------------------------------------- *)

(*
interreduce ord [] eqs';;
*)

(* ------------------------------------------------------------------------- *)
(* Overall function with post-simplification (but not dynamically).          *)
(* ------------------------------------------------------------------------- *)

(* Run completion on eqs under the LPO induced by the symbol list wts, then  *)
(* interreduce the resulting rules.  Fails up front if some input equation   *)
(* l = r has ord r l, i.e. is oriented against the ordering.  The initial    *)
(* critical pairs are those of every pair of input equations.                *)
let complete_and_simplify wts eqs =
  let ord = lpo_ge (weight wts) in
  let misordered (Atom(R("=",[l;r]))) = ord r l in
  if exists misordered eqs
  then failwith "Initial equations not ordered by given ordering" else
  let crits = unions(allpairs critical_pairs eqs eqs) in
  let completed = complete ord (eqs,[],crits) in
  interreduce ord [] completed;;

(* ------------------------------------------------------------------------- *)
(* Central groupoids (K&B example 6).                                        *)
(* ------------------------------------------------------------------------- *)

(*
let eqs =  [<<(a * b) * (b * c) = b>>];;

complete_and_simplify ["*"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Inverse property (K&B example 4).                                         *)
(* ------------------------------------------------------------------------- *)

let eqs =  [<<i(a) * (a * b) = b>>];;

complete_and_simplify ["1"; "*"; "i"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Auxiliary result used to justify extension for example 9.                 *)
(* ------------------------------------------------------------------------- *)

(meson ** equalitize)
 <<(forall x y z. x * y = x * z ==> y = z) <=>
   (forall x z. exists w. forall y. z = x * y ==> w = y)>>;;

skolemize <<forall x z. exists w. forall y. z = x * y ==> w = y>>;;

let eqs =
  [<<f(a,a*b) = b>>; <<g(a*b,b) = a>>; <<1 * a = a>>; <<a * 1 = a>>];;

complete_and_simplify ["1"; "*"; "f"; "g"] eqs;;

(* ------------------------------------------------------------------------- *)
(* K&B example 7, where we need to divide through.                           *)
(* ------------------------------------------------------------------------- *)

let eqs =  [<<f(a,f(b,c,a),d) = c>>; <<f(a,b,c) = g(a,b)>>;
                     <<g(a,b) = h(b)>>];;

complete_and_simplify ["h"; "g"; "f"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Group theory I (K & B example 1).                                         *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<1 * x = x>>; <<i(x) * x = 1>>; <<(x * y) * z = x * y * z>>];;

complete_and_simplify ["1"; "*"; "i"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Inverse property (K&B example 4).                                         *)
(* ------------------------------------------------------------------------- *)

let eqs =  [<<i(a) * (a * b) = b>>];;

complete_and_simplify ["1"; "*"; "i"] eqs;;

let eqs =  [<<a * (i(a) * b) = b>>];;

complete_and_simplify ["1"; "*"; "i"] eqs;;

(* ------------------------------------------------------------------------- *)
(* The cancellation law (K&B example 9).                                     *)
(* ------------------------------------------------------------------------- *)

let eqs =  [<<f(a,a*b) = b>>; <<g(a*b,b) = a>>];;

complete_and_simplify ["*"; "f"; "g"] eqs;;

let eqs =
  [<<f(a,a*b) = b>>; <<g(a*b,b) = a>>; <<1 * a = a>>; <<a * 1 = a>>];;

complete_and_simplify ["1"; "*"; "f"; "g"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Loops (K&B example 10).                                                   *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<a * \(a,b) = b>>; <</(a,b) * b = a>>; <<1 * a = a>>; <<a * 1 = a>>];;

complete_and_simplify ["1"; "*"; "\\"; "/"] eqs;;

let eqs =
 [<<a * \(a,b) = b>>; <</(a,b) * b = a>>; <<1 * a = a>>; <<a * 1 = a>>;
  <<f(a,a*b) = b>>; <<g(a*b,b) = a>>];;

complete_and_simplify ["1"; "*"; "\\"; "/"; "f"; "g"] eqs;;

(* ------------------------------------------------------------------------- *)
(* (r,l)-systems (K&B example 13).                                           *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<(x * y) * z = x * y * z>>; <<x * 1 = x>>; <<i(x) * x = 1>>];;

complete_and_simplify ["1"; "*"; "i"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Central groupoids II. (K&B example 16).                                   *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<(a * a) * a = one(a)>>;
  <<a * (a * a) = two(a)>>;
  <<(a * b) * (b * c) = b>>;
  <<two(a) * b = a * b>>];;

complete_and_simplify ["one"; "two"; "*"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Simply congruence closure.                                                *)
(* ------------------------------------------------------------------------- *)

let eqs =  [<<f(f(f(f(f(1))))) = 1>>; <<f(f(f(1))) = 1>>];;

complete_and_simplify ["1"; "f"] eqs;;

(* ------------------------------------------------------------------------- *)
(* A rather simple example from Baader & Nipkow, p. 141.                     *)
(* ------------------------------------------------------------------------- *)

let eqs =  [<<f(f(x)) = g(x)>>];;

complete_and_simplify ["g"; "f"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Some of the exercises (these are taken from Baader & Nipkow).             *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<f(f(x)) = f(x)>>;
  <<g(g(x)) = f(x)>>;
  <<f(g(x)) = g(x)>>;
  <<g(f(x)) = f(x)>>];;

complete_and_simplify ["f"; "g"] eqs;;

let eqs =  [<<f(g(f(x))) = g(x)>>];;

complete_and_simplify ["f"; "g"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Inductive theorem proving example.                                        *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<0 + y = y>>;
  <<SUC(x) + y = SUC(x + y)>>;
  <<append(nil,l) = l>>;
  <<append(h::t,l) = h::append(t,l)>>;
  <<length(nil) = 0>>;
  <<length(h::t) = SUC(length(t))>>;
  <<rev(nil) = nil>>;
  <<rev(h::t) = append(rev(t),h::nil)>>];;

complete_and_simplify
   ["0"; "nil"; "SUC"; "::"; "+"; "length"; "append"; "rev"] eqs;;

let iprove eqs' tm =
 complete_and_simplify
   ["0"; "nil"; "SUC"; "::"; "+"; "append"; "rev"; "length"]
   (tm :: eqs' @ eqs);;

iprove [] <<x + 0 = x>>;;

iprove [] <<x + SUC(y) = SUC(x + y)>>;;

iprove [] <<(x + y) + z = x + y + z>>;;

iprove [] <<length(append(x,y)) = length(x) + length(y)>>;;

iprove [] <<append(append(x,y),z) = append(x,append(y,z))>>;;

iprove [] <<append(x,nil) = x>>;;

iprove [<<append(append(x,y),z) = append(x,append(y,z))>>;
        <<append(x,nil) = x>>]
        <<rev(append(x,y)) = append(rev(y),rev(x))>>;;

iprove [<<rev(append(x,y)) = append(rev(y),rev(x))>>;
        <<append(x,nil) = x>>;
        <<append(append(x,y),z) = append(x,append(y,z))>>]
        <<rev(rev(x)) = x>>;;

(* ------------------------------------------------------------------------- *)
(* Here it's not immediately so obvious since we get extra equations.        *)
(* ------------------------------------------------------------------------- *)

iprove [] <<rev(rev(x)) = x>>;;

*)
(* ========================================================================= *)
(* Ordered rewriting.                                                        *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Rewriting constrained by ordering.                                        *)
(* ------------------------------------------------------------------------- *)

(* Rewrite t at its root with eq, left to right, but accept the result only  *)
(* when ord t result holds.  Raises Failure "orewrite1: not ordered" when    *)
(* the ordering check fails and Failure "orewrite1: not an equation" when eq *)
(* is not a two-argument "=" atom; matching failures propagate.              *)
let orewrite1 ord eq t =
  match eq with
    Atom(R("=",[lhs;rhs])) ->
        let inst = term_match lhs t in
        let t' = termsubst inst rhs in
        if not (ord t t') then failwith "orewrite1: not ordered" else t'
  | _ -> failwith "orewrite1: not an equation";;

(* ------------------------------------------------------------------------- *)
(* Rewriting with set of ordered and unordered equations.                    *)
(* ------------------------------------------------------------------------- *)

(* Rewrite tm at its root: first try the oriented equations oeqs            *)
(* unconditionally; if that raises Failure, try the unoriented equations     *)
(* ueqs, each guarded by the ordering check of orewrite1.                    *)
let orewrite ord (oeqs,ueqs) tm =
  let plain eq = rewrite1 eq tm
  and guarded eq = orewrite1 ord eq tm in
  try tryfind plain oeqs
  with Failure _ -> tryfind guarded ueqs;;

(* ------------------------------------------------------------------------- *)
(* Split equations into orientable and non-orientable; orient the former.    *)
(* ------------------------------------------------------------------------- *)

(* Add the equation l = r to the pair (oriented, unoriented):                *)
(*   - ord l r: goes to the oriented list as it stands;                      *)
(*   - else ord r l: goes to the oriented list flipped to r = l;             *)
(*   - else: goes to the unoriented list unchanged.                          *)
let tryorient ord (Atom(R("=",[l;r])) as eq) (oeqs,ueqs) =
  if ord l r then (eq::oeqs,ueqs) else
  if not (ord r l) then (oeqs,eq::ueqs) else
  let flipped = Atom(R("=",[r;l])) in
  (flipped::oeqs,ueqs);;

(* ------------------------------------------------------------------------- *)
(* Evaluate "hypothetical" LPO based on ordering of variables.               *)
(* ------------------------------------------------------------------------- *)

(* Variant of the lexicographic path order that also takes vord, a ranking   *)
(* function on variable names.  lpoh_gt vord w s t holds when:               *)
(*   - both are variables and vord ranks the left one strictly higher;       *)
(*   - t is a variable x, s is not, and some variable y in `fvt s` has       *)
(*     vord y >= vord x;                                                     *)
(*   - both are applications and either some argument of s is lpoh_ge t, or  *)
(*     s is lpoh_gt every argument of t and, in addition, either the heads   *)
(*     are equal with argument lists related by lexord, or w ranks the head  *)
(*     of s (with its arity) above that of t.                                *)
(* Every other case is false.  lpoh_ge is the reflexive closure.             *)
(* Uses && and || instead of the deprecated & and or operators, which newer  *)
(* compilers no longer provide; grouping is unchanged.                       *)
let rec lpoh_gt vord w s t =
  match (s,t) with
    (Var y,Var x) -> vord y > vord x
  | (_,Var x) -> exists (fun y -> vord y >= vord x) (fvt s)
  | (Fn(f,fargs),Fn(g,gargs)) ->
        exists (fun si -> lpoh_ge vord w si t) fargs ||
        forall (lpoh_gt vord w s) gargs &&
        (f = g && lexord (lpoh_gt vord w) fargs gargs ||
         w (f,length fargs) (g,length gargs))
  | _ -> false

and lpoh_ge vord w s t = (s = t) || lpoh_gt vord w s t;;

(* ------------------------------------------------------------------------- *)
(* All ways to identify subsets of the variables in a formula.               *)
(* ------------------------------------------------------------------------- *)

(* All ways of splitting the elements of l into blocks (lists of lists).     *)
(* Elements are added one at a time, starting from the last: given a         *)
(* partition of the elements handled so far, the new element is either put   *)
(* into a new singleton block or added to the front of one existing block.   *)
(* The empty list has exactly one partition, the empty one.                  *)
let allpartitions l =
  let allinsertions x part acc =
    let add_to_block p acc' = ((x::p)::(subtract part [p])) :: acc' in
    itlist add_to_block part (([x]::part)::acc) in
  let extend h parts = itlist (allinsertions h) parts [] in
  itlist extend l [[]];;

(* Extend the instantiation fn so that every variable name in vars is sent   *)
(* to the variable named by the first element of vars.  vars is passed to    *)
(* `hd`, so it is expected to be non-empty.                                  *)
let identify vars fn =
  let representative = Var(hd vars) in
  let bind v env = (v |-> representative) env in
  itlist bind vars fn;;

(* All instances of fm obtained by identifying free variables: for every     *)
(* partition of the free variables, each block is collapsed onto its first   *)
(* variable and the resulting instantiation is applied to fm.                *)
let allidentifications fm =
  let instantiate partition =
    let merge = itlist identify partition undefined in
    formsubst merge fm in
  map instantiate (allpartitions (fv fm));;

(* ------------------------------------------------------------------------- *)
(* Find all orderings of variables.                                          *)
(* ------------------------------------------------------------------------- *)

(* All orderings of the elements of l.  For each element h, every            *)
(* permutation of l with h removed (via `subtract`) is prefixed by h.  The   *)
(* empty list has the single permutation [].                                 *)
let rec allpermutations l =
  match l with
    [] -> [[]]
  | _ ->
      let starting_with h acc =
        let rest = allpermutations (subtract l [h]) in
        map (fun t -> h::t) rest @ acc in
      itlist starting_with l [];;

(* One ranking function per permutation of l: each maps an element x to      *)
(* `index x` within that permutation.                                        *)
let allvarorders l =
  let rank_in vlis = fun x -> index x vlis in
  map rank_in (allpermutations l);;

(* ------------------------------------------------------------------------- *)
(* Test critical triple for joinability under all variable orders.           *)
(* ------------------------------------------------------------------------- *)

(* Does the equation s = t become trivial under ordered rewriting?  Both     *)
(* sides are normalised with depth (orewrite ord oueqs) and the results are  *)
(* compared for equality.                                                    *)
let ojoinable ord oueqs (Atom(R("=",[s;t]))) =
  let normalize = depth (orewrite ord oueqs) in
  normalize s = normalize t;;

(* Check that eq is joinable in every case considered: for each way of       *)
(* identifying its free variables, and for each ranking of the free          *)
(* variables of the resulting instance, the instance must be ojoinable       *)
(* under lpoh_gt with that ranking and the weighting w.                      *)
let allojoinable w oueqs eq =
  let joinable_everywhere eq' =
    let joins_under vord = ojoinable (lpoh_gt vord w) oueqs eq' in
    forall joins_under (allvarorders(fv eq')) in
  forall joinable_everywhere (allidentifications eq);;

(* ------------------------------------------------------------------------- *)
(* Find the critical pairs not joinable by naive variable order splits.      *)
(* ------------------------------------------------------------------------- *)

(* Filter the critical pairs critts, accumulating in unj those that cannot   *)
(* be shown joinable.  For each pair s = t:                                  *)
(*   - normalise both sides by ordered rewriting with oueqs;                 *)
(*   - drop the pair if the normal forms coincide, or if allojoinable        *)
(*     succeeds on the original pair;                                        *)
(*   - otherwise record the equation between the two normal forms.           *)
(* The ordering used for normalisation is built here as lpo_gt w, the same   *)
(* ordering unjoinables uses to orient oueqs.  Previously the body referred  *)
(* to a free variable ord, which resolved to an unrelated top-level example  *)
(* binding instead of an ordering derived from w.                            *)
(* A list element that is not a two-argument "=" atom is not covered by the  *)
(* patterns.                                                                 *)
let rec unjoined w (oueqs,unj,critts) =
  match critts with
    (Atom(R("=",[s;t])) as eq)::ocritts ->
        let ord = lpo_gt w in
        let s' = depth (orewrite ord oueqs) s
        and t' = depth (orewrite ord oueqs) t in
        if s' = t' || allojoinable w oueqs eq
        then unjoined w (oueqs,unj,ocritts)
        else unjoined w (oueqs,Atom(R("=",[s';t']))::unj,ocritts)
  | [] -> unj;;

(* ------------------------------------------------------------------------- *)
(* Overall function to return possibly-unjoinable critical pairs.            *)
(* ------------------------------------------------------------------------- *)

(* Return the critical pairs of eqs that unjoined cannot dispose of.  The    *)
(* symbol list plis induces the weighting; the equations are split into      *)
(* oriented and unoriented with tryorient under lpo_gt, and the critical     *)
(* pairs of every pair of equations are then filtered.                       *)
let unjoinables plis eqs =
  let w = weight plis in
  let nothing_yet = ([],[]) in
  let oueqs = itlist (tryorient (lpo_gt w)) eqs nothing_yet in
  let critts = unions (allpairs critical_pairs eqs eqs) in
  unjoined w (oueqs,[],critts);;

(* ------------------------------------------------------------------------- *)
(* Example: pure AC.                                                         *)
(* ------------------------------------------------------------------------- *)

(*
let eqs = [<<x * y = y * x>>; <<(x * y) * z = x * y * z>>];;

unjoinables ["*"] eqs;;

(* ------------------------------------------------------------------------- *)
(* 4.2: associativity and commutativity.                                     *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<x * y = y * x>>;
  <<(x * y) * z = x * y * z>>;
  <<x * y * z = y * x * z>>];;

unjoinables ["*"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Example of normalizing expressions.                                       *)
(* ------------------------------------------------------------------------- *)

let ord = lpo_gt (fun (s,n) (s',n') -> s > s' or s = s' & n > n');;

let acnorm =
  depth(orewrite ord (itlist (tryorient ord) eqs ([],[])));;

acnorm <<|(4 * 3) * (1 * 5 * 1 * 2)|>>;;

(* ------------------------------------------------------------------------- *)
(* 4.4: associativity, commutativity and idempotence.                        *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<(x * y) * z = x * y * z>>;
  <<x * y = y * x>>;
  <<x * (y * z) = y * (x * z)>>;
  <<x * x = x>>;
  <<x * (x * y) = x * y>>];;

unjoinables ["*"] eqs;;

(* ------------------------------------------------------------------------- *)
(* 4.8: Boolean rings.                                                       *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<x + y = y + x>>;               <<x * y = y * x>>;
  <<x + y + z = y + x + z>>;       <<x * y * z = y * x * z>>;
  <<x + x = 0>>;                   <<x + 0 = x>>;
  <<0 + x = x>>;                   <<x * x = x>>;
  <<1 * x = x>>;                   <<x * 1 = x>>;
  <<x * 0 = 0>>;                   <<0 * x = 0>>;
  <<(x * y) * z = x * y * z>>;      <<(x + y) + z = x + y + z>>;
  <<x * (y + z) = x * y + x * z>>; <<(x + y) * z = x * z + y * z>>;
  <<x * (x * y) = x * y>>;         <<x + (x + y) = y>>];;

unjoinables ["0"; "1"; "+"; "*"] eqs;;

(* ------------------------------------------------------------------------- *)
(* Translation to propositional logic.                                       *)
(* ------------------------------------------------------------------------- *)

let rec prop_of_boolterm tm =
  match tm with
    Fn("+",[p;q]) -> Not(Iff(prop_of_boolterm p,prop_of_boolterm q))
  | Fn("*",[p;q]) -> And(prop_of_boolterm p,prop_of_boolterm q)
  | Fn("0",[]) -> False
  | Fn("1",[]) -> True
  | Var x -> Atom(R(x,[]));;

let prop_of_bool (Atom(R("=",[s;t]))) =
  Iff(prop_of_boolterm s,prop_of_boolterm t);;

forall tautology (map prop_of_bool eqs);;
*)

(* ------------------------------------------------------------------------- *)
(* Translation back.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Translate a propositional formula into a term over the symbols "+", "*",  *)
(* "0" and "1":                                                              *)
(*   False -> 0,  True -> 1,  nullary atom p -> constant p,                  *)
(*   not p    -> p + 1,                                                      *)
(*   p and q  -> p * q,                                                      *)
(*   p or q   -> p + (q + p * q),                                            *)
(*   p imp q  -> translation of (not p) or q,                                *)
(*   p iff q  -> p + (q + 1).                                                *)
(* Formulas of any other shape (for instance atoms with arguments) are not   *)
(* covered by the match.                                                     *)
let rec bool_of_prop fm =
  let one = Fn("1",[]) in
  let add a b = Fn("+",[a;b]) in
  match fm with
    False -> Fn("0",[])
  | True -> one
  | Atom(R(p,[])) -> Fn(p,[])
  | Not(p) -> add (bool_of_prop p) one
  | And(p,q) ->
        let p' = bool_of_prop p and q' = bool_of_prop q in
        Fn("*",[p';q'])
  | Or(p,q) ->
        let p' = bool_of_prop p and q' = bool_of_prop q in
        add p' (add q' (Fn("*",[p';q'])))
  | Imp(p,q) -> bool_of_prop(Or(Not p,q))
  | Iff(p,q) ->
        let p' = bool_of_prop p and q' = bool_of_prop q in
        add p' (add q' one);;

(* ------------------------------------------------------------------------- *)
(* Canonical simplifier for Boolean rings.                                   *)
(* ------------------------------------------------------------------------- *)

(*
let ord =
  let w f g =
    match (f,g) with
      (("*",2),("+",2)) -> true
    | ((_,2),(_,0)) -> true
    | ((s,0),(s',0)) -> s > s'
    | _ -> false in
  lpo_gt w;;

let boolnorm =
  depth(orewrite ord (itlist (tryorient ord) eqs ([],[])));;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

boolnorm (bool_of_prop <<p /\ q \/ ~p \/ ~q>>);;

boolnorm (bool_of_prop <<(p ==> q) \/ (q ==> p)>>);;

boolnorm (bool_of_prop <<p \/ q ==> q \/ (p <=> q)>>);;

(* ------------------------------------------------------------------------- *)
(* 4.5: Groups of exponent two                                               *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<(x * y) * z = x * y * z>>;
  <<x * y = y * x>>;
  <<x * (y * z) = y * (x * z)>>;
  <<x * x = 1>>;
  <<x * (x * y) = y>>;
  <<x * 1 = x>>;
  <<1 * x = x>>];;

unjoinables ["1"; "*"] eqs;;

(* ------------------------------------------------------------------------- *)
(* 4.7: Distributivity.                                                      *)
(* ------------------------------------------------------------------------- *)

let eqs =
 [<<(x * y) * z = x * y * z>>;
  <<x * y = y * x>>;
  <<x * y * z = y * x * z>>;
  <<x * (y + z) = x * y + x * z>>;
  <<(x + y) * z = x * z + y * z>>];;

unjoinables ["+"; "*"] eqs;;

*)
(* ========================================================================= *)
(* Equality elimination including Brand transformation and relatives.        *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(*
(meson ** equalitize)
 <<(forall x y z. (x * y) * z = x * (y * z)) <=>
   (forall u v w x y z.
         (x * y = u) /\ (y * z = w) ==> ((x * w = v) <=> (u * z = v)))>>;;

(* ------------------------------------------------------------------------- *)
(* Example of using 3-place predicate for group theory.                      *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x. P(1,x,x)) /\
   (forall x. P(i(x),x,1)) /\
   (forall u v w x y z. P(x,y,u) /\ P(y,z,w) ==> (P(x,w,v) <=> P(u,z,v)))
   ==> forall x. P(x,1,x)>>;;

meson
 <<(forall x. P(1,x,x)) /\
   (forall x. P(i(x),x,1)) /\
   (forall u v w x y z. P(x,y,u) /\ P(y,z,w) ==> (P(x,w,v) <=> P(u,z,v)))
   ==> forall x. P(x,i(x),1)>>;;

(* ------------------------------------------------------------------------- *)
(* The x^2 = 1 implies Abelian problem.                                      *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x. P(1,x,x)) /\
   (forall x. P(x,x,1)) /\
   (forall u v w x y z. P(x,y,u) /\ P(y,z,w) ==> (P(x,w,v) <=> P(u,z,v)))
   ==> forall a b c. P(a,b,c) ==> P(b,a,c)>>;;

(* ------------------------------------------------------------------------- *)
(* See how efficiency drops when we assert completeness.                     *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x. P(1,x,x)) /\
   (forall x. P(x,x,1)) /\
   (forall x y. exists z. P(x,y,z)) /\
   (forall u v w x y z. P(x,y,u) /\ P(y,z,w) ==> (P(x,w,v) <=> P(u,z,v)))
   ==> forall a b c. P(a,b,c) ==> P(b,a,c)>>;;

(* ------------------------------------------------------------------------- *)
(* Lemma for equivalence elimination.                                        *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x. R(x,x)) /\
   (forall x y. R(x,y) ==>  R(y,x)) /\
   (forall x y z. R(x,y) /\ R(y,z) ==> R(x,z))
   <=> (forall x y. R(x,y) <=> (forall z. R(x,z) <=> R(y,z)))>>;;

(* ------------------------------------------------------------------------- *)
(* Same thing for reflexivity and transitivity without symmetry.             *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x. R(x,x)) /\
   (forall x y z. R(x,y) /\ R(y,z) ==> R(x,z))
   <=> (forall x y. R(x,y) <=> (forall z. R(y,z) ==> R(x,z)))>>;;

(* ------------------------------------------------------------------------- *)
(* And for just symmetry.                                                    *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x y. R(x,y) ==>  R(y,x)) <=>
   (forall x y. R(x,y) <=> R(x,y) /\ R(y,x))>>;;

(* ------------------------------------------------------------------------- *)
(* Show how Equiv' reduces to triviality.                                    *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x. (forall w. R'(x,w) <=> R'(x,w))) /\
   (forall x y. (forall w. R'(x,w) <=> R'(y,w))
                ==> (forall w. R'(y,w) <=> R'(x,w))) /\
   (forall x y z. (forall w. R'(x,w) <=> R'(y,w)) /\
                  (forall w. R'(y,w) <=> R'(z,w))
                  ==> (forall w. R'(x,w) <=> R'(z,w)))>>;;

(* ------------------------------------------------------------------------- *)
(* More auxiliary proofs for Brand's S and T modification.                   *)
(* ------------------------------------------------------------------------- *)

meson
 <<(forall x y. R(x,y) <=> (forall z. R'(x,z) <=> R'(y,z))) /\
   (forall x. R'(x,x))
   ==> forall x y. ~R'(x,y) ==> ~R(x,y)>>;;

meson
 <<(forall x y. R(x,y) <=> (forall z. R'(y,z) ==> R'(x,z))) /\
   (forall x. R'(x,x))
   ==> forall x y. ~R'(x,y) ==> ~R(x,y)>>;;

meson
 <<(forall x y. R(x,y) <=> R'(x,y) /\ R'(y,x))
   ==> forall x y. ~R'(x,y) ==> ~R(x,y)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Brand's S and T modifications on clauses.                                 *)
(* ------------------------------------------------------------------------- *)

(* S-modification of a clause cl (a list of literals).  While dest_eq        *)
(* succeeds on some literal, giving sides (s,t), the equation s = t is       *)
(* removed, the remainder is processed recursively, and each result is       *)
(* returned twice: once with s = t and once with t = s put back in front.    *)
(* When no literal can be destructed as an equation the result is [cl].      *)
let rec modify_S cl =
  try
    let (lhs,rhs) = tryfind dest_eq cl in
    let forward = Atom(R("=",[lhs;rhs])) in
    let backward = Atom(R("=",[rhs;lhs])) in
    let rest = modify_S (subtract cl [forward]) in
    let with_backward = map (fun c -> backward::c) rest in
    let with_forward = map (fun c -> forward::c) rest in
    with_forward @ with_backward
  with Failure _ -> [cl];;

(* T-modification of a clause.  Every positive equation s = t at the head    *)
(* of the list is replaced by the two literals ~(t = w) and s = w, where w   *)
(* is a variable chosen by variant so as to differ from the free variables   *)
(* of the equation and of the already transformed tail.  Other literals are  *)
(* kept unchanged.                                                           *)
let rec modify_T = function
    [] -> []
  | (Atom(R("=",[lhs;rhs])) as eqn)::lits ->
        let lits' = modify_T lits in
        let avoid =
          itlist (fun lit acc -> union (fv lit) acc) lits' (fv eqn) in
        let w = Var(variant "w" avoid) in
        Not(Atom(R("=",[rhs;w]))) :: Atom(R("=",[lhs;w])) :: lits'
  | lit::lits -> lit :: modify_T lits;;

(* ------------------------------------------------------------------------- *)
(* Finding nested non-variable subterms.                                     *)
(* ------------------------------------------------------------------------- *)

(* Return the first term in the list that is not a variable; failure is      *)
(* whatever find raises when no element satisfies the test.                  *)
let find_nonvar tms =
  let is_nonvar tm =
    match tm with
      Var _ -> false
    | _ -> true in
  find is_nonvar tms;;

(* For a function application, return its first non-variable immediate       *)
(* argument; a bare variable has no nested subterm, so that case fails.      *)
let find_nestnonvar = function
    Fn(_,args) -> find_nonvar args
  | Var _ -> failwith "findnvsubt";;

(* Find a non-variable subterm to abstract in a literal.  For a binary       *)
(* equation we look one level inside either side; for any other atom we      *)
(* look at the top-level arguments; negation is looked through.  Formulas    *)
(* that are not literals are not handled by the match.                       *)
let rec find_nvsubterm = function
    Not lit -> find_nvsubterm lit
  | Atom(R("=",[lhs;rhs])) -> tryfind find_nestnonvar [lhs;rhs]
  | Atom(R(_,args)) -> find_nonvar args;;

(* ------------------------------------------------------------------------- *)
(* Replacement (substitution for non-variable) in term and literal.          *)
(* ------------------------------------------------------------------------- *)

(* Replace subterms according to rfn.  If apply rfn succeeds on the whole    *)
(* term, its result is used (and not traversed further); otherwise, on a     *)
(* Failure, the replacement is pushed into the arguments of a function       *)
(* application and any other term is returned unchanged.                     *)
let rec replacet rfn tm =
  try apply rfn tm
  with Failure _ ->
    (match tm with
       Fn(f,args) ->
         let descend = replacet rfn in
         Fn(f,map descend args)
     | _ -> tm);;

(* Apply replacet rfn to every argument of every atom in the formula fm.     *)
let replace rfn fm =
  let rewrite_atom (R(pred,args)) =
    Atom(R(pred,map (replacet rfn) args)) in
  onatoms rewrite_atom fm;;

(* ------------------------------------------------------------------------- *)
(* E-modification of a clause.                                               *)
(* ------------------------------------------------------------------------- *)

(* E-modification worker.  Here cls is a list of literals and fvs the        *)
(* variable names to avoid.  If the head literal contains a subterm found    *)
(* by find_nvsubterm, that subterm is replaced throughout the whole list by  *)
(* a fresh variable w, the literal ~(t = w) is added at the front, and the   *)
(* process restarts on the new list.  When find_nvsubterm fails on the head  *)
(* literal, it is kept and the remaining literals are processed.             *)
let rec emodify fvs cls =
  match cls with
    [] -> []
  | lit::rest ->
        try
          let sub = find_nvsubterm lit in
          let fresh = variant "w" fvs in
          let wvar = Var fresh in
          let lits' = map (replace (sub := wvar)) cls in
          let defn = Not(Atom(R("=",[sub;wvar]))) in
          emodify (fresh::fvs) (defn::lits')
        with Failure _ -> lit::(emodify fvs rest);;

(* E-modify a clause, seeding emodify with the free variables of all its     *)
(* literals so that freshly introduced variables avoid them.                 *)
let modify_E cls =
  let fvs = itlist (fun lit acc -> union (fv lit) acc) cls [] in
  emodify fvs cls;;

(* ------------------------------------------------------------------------- *)
(* Overall Brand transformation.                                             *)
(* ------------------------------------------------------------------------- *)

(* Full Brand transformation of a clause list: E-modify each clause, take    *)
(* the union of the S-modifications of the results, T-modify each of those,  *)
(* and finally add the reflexivity clause x = x at the front.                *)
let brand cls =
  let reflexivity = [Atom(R("=",[Var"x";Var "x"]))] in
  let e_modified = map modify_E cls in
  let s_modified =
    itlist (fun cl acc -> union (modify_S cl) acc) e_modified [] in
  reflexivity :: map modify_T s_modified;;

(* ------------------------------------------------------------------------- *)
(* Incorporation into MESON.                                                 *)
(* ------------------------------------------------------------------------- *)

(* MESON on the Brand-transformed clauses of fm.  The clauses are turned     *)
(* into rules via contrapositives and expand is run from the goal False      *)
(* under deepen, starting at limit 0; the result is the limit value n at     *)
(* which the attempt function returned.                                      *)
let bpuremeson fm =
  let clauses = brand (clausal (specialize (pnf fm))) in
  let rules =
    itlist (fun cl acc -> contrapositives cl @ acc) clauses [] in
  let attempt n =
    ignore (expand rules [] False (fun env -> env) (undefined,n,0));
    n in
  deepen attempt 0;;

(* Refutation wrapper: negate the generalized formula, Skolemize it, split   *)
(* it with simpdnf, and run bpuremeson on the conjunction of each disjunct.  *)
let bmeson fm =
  let refutand = askolemize (Not (generalize fm)) in
  map (fun dj -> bpuremeson (list_conj dj)) (simpdnf refutand);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
let emeson fm = meson (equalitize fm);;

let ewd =
 <<(forall x. f(x) ==> g(x)) /\
   (exists x. f(x)) /\
   (forall x y. g(x) /\ g(y) ==> x = y)
   ==> forall y. g(y) ==> f(y)>>;;

time bmeson ewd;;
time emeson ewd;;

*)
(* ========================================================================= *)
(* Paramodulation.                                                           *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Find paramodulations with l = r inside the term tm.  For every position   *)
(* where l unifies (via fullunify) with the subterm, rfn is called with the  *)
(* unifier and the term obtained by putting r at that position.  At a        *)
(* function application a unification Failure just means no overlap at the   *)
(* top; at a variable the unification is attempted without such protection.  *)
let rec overlapv (l,r) tm rfn =
  let at_top () = [rfn (fullunify [l,tm]) r] in
  match tm with
    Var _ -> at_top ()
  | Fn(f,args) ->
        let top = try at_top () with Failure _ -> [] in
        let rebuild inst args' = rfn inst (Fn(f,args')) in
        listcases (overlapv (l,r)) rebuild args top;;

(* ------------------------------------------------------------------------- *)
(* Find paramodulations with l = r inside a literal fm.                      *)
(* ------------------------------------------------------------------------- *)

(* Lift overlapv to a literal: overlaps are sought in the arguments of the   *)
(* atom, and the continuation rfn is wrapped so that the atom (and any       *)
(* enclosing negation) is rebuilt around the modified arguments.             *)
let rec overlapl (l,r) fm rfn =
  match fm with
    Not(body) ->
        let renegate inst body' = rfn inst (Not(body')) in
        overlapl (l,r) body renegate
  | Atom(R(pred,args)) ->
        let rebuild inst args' = rfn inst (Atom(R(pred,args'))) in
        listcases (overlapv (l,r)) rebuild args []
  | _ -> failwith "overlapl: not a literal";;

(* ------------------------------------------------------------------------- *)
(* Now find paramodulations within a clause.                                 *)
(* ------------------------------------------------------------------------- *)

(* Collect overlaps of the equation (l,r) with every literal of the clause   *)
(* cl, adding them to the accumulator acc.                                   *)
let overlapc eqn cl rfn acc =
  let (l,r) = eqn in
  listcases (overlapl (l,r)) rfn cl acc;;

(* ------------------------------------------------------------------------- *)
(* Overall paramodulation of ocl by equations in pcl.                        *)
(* ------------------------------------------------------------------------- *)

(* Paramodulate into the clause ocl using each literal of pcl that dest_eq   *)
(* accepts as an equation, in both orientations.  Each resulting clause is   *)
(* the instantiated combination of pcl minus the equation used and the       *)
(* rewritten ocl.                                                            *)
let paramodulate pcl ocl =
  let from_equation eqn acc =
    let others = subtract pcl [eqn] in
    let (l,r) = dest_eq eqn in
    let rebuild inst ocl' = smap (formsubst inst) (others @ ocl') in
    overlapc (l,r) ocl rebuild (overlapc (r,l) ocl rebuild acc) in
  itlist from_equation (filter (can dest_eq) pcl) [];;

(* Paramodulate two clauses against each other in both directions, after     *)
(* renaming them with the distinct prefixes x and y.                         *)
let paramodulate_clauses cls1 cls2 =
  let left = rename "x" cls1 in
  let right = rename "y" cls2 in
  let into_left = paramodulate right left in
  let into_right = paramodulate left right in
  into_right @ into_left;;

(* ------------------------------------------------------------------------- *)
(* Incorporation into resolution loop.                                       *)
(* ------------------------------------------------------------------------- *)

(* Given-clause loop combining resolution and paramodulation.  The first     *)
(* unused clause is moved to the used set and combined with every used       *)
(* clause (inferences that raise Failure are dropped by mapfilter).  The     *)
(* loop returns true once the empty clause is among the new clauses and      *)
(* fails when the unused list runs out.  A progress line is printed on       *)
(* every iteration.                                                          *)
let rec paraloop (used,unused) =
  match unused with
    [] -> failwith "No proof found"
  | given::pending ->
        let progress =
          string_of_int(length used) ^ " used; " ^
          string_of_int(length unused) ^ " unused." in
        print_string progress;
        print_newline();
        let used' = insert given used in
        let paramods =
          itlist (@) (mapfilter (paramodulate_clauses given) used') [] in
        let news =
          itlist (@) (mapfilter (resolve_clauses given) used') paramods in
        mem [] news ||
        paraloop(used',itlist (incorporate given) news pending);;

(* Run paraloop on the non-tautologous clauses of fm, with the reflexivity   *)
(* clause x = x placed first in the unused list and nothing yet used.        *)
let pure_paramodulation fm =
  let reflexivity = [Atom(R("=",[Var "x"; Var "x"]))] in
  let clauses = filter (non tautologous) (clausal(specialize(pnf fm))) in
  paraloop([],reflexivity::clauses);;

(* Refutation wrapper: negate the generalized formula, Skolemize it, split   *)
(* it with simpdnf, and run pure_paramodulation on each disjunct.            *)
let paramodulation fm =
  let refutand = askolemize(Not(generalize fm)) in
  map (fun dj -> pure_paramodulation (list_conj dj)) (simpdnf refutand);;

(* ------------------------------------------------------------------------- *)
(* Test.                                                                     *)
(* ------------------------------------------------------------------------- *)

(* Test formula with equality: f implies g, some f exists, and any two g     *)
(* elements are equal; conclusion: every g element is an f element.          *)
let ewd =
 <<(forall x. f(x) ==> g(x)) /\
   (exists x. f(x)) /\
   (forall x y. g(x) /\ g(y) ==> x = y)
   ==> forall y. g(y) ==> f(y)>>;;

(*
paramodulation ewd;;
*)

(* Test formula with equality: if applying f five times and three times      *)
(* both give back c, then applying f once gives back c.                      *)
let fm =
 <<forall c. f(f(f(f(f(c))))) = c /\ f(f(f(c))) = c ==> f(c) = c>>;;

(*
paramodulation fm;;
*)

(* Test formula with equality: an idempotent f that reaches every element    *)
(* must be the identity.                                                     *)
let ewd' =
 <<(forall x. f(f(x)) = f(x)) /\ (forall x. exists y. f(y) = x)
   ==> forall x. f(x) = x>>;;

(*
paramodulation ewd';;
*)
(* ========================================================================= *)
(* Special procedures for decidable subsets of first order logic.            *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* The Los example; see how Skolemized form has no non-nullary functions.    *)
(* ------------------------------------------------------------------------- *)

(*
let los =
 <<(forall x y z. P(x,y) /\ P(y,z) ==> P(x,z)) /\
   (forall x y z. Q(x,y) /\ Q(y,z) ==> Q(x,z)) /\
   (forall x y. P(x,y) ==> P(y,x)) /\
   (forall x y. P(x,y) \/ Q(x,y))
   ==> (forall x y. P(x,y)) \/ (forall x y. Q(x,y))>>;;
skolemize(Not los);;

(* ------------------------------------------------------------------------- *)
(* The old DP procedure works.                                               *)
(* ------------------------------------------------------------------------- *)

davisputnam los;;
*)

(* ------------------------------------------------------------------------- *)
(* However, we can just form all the ground instances.                       *)
(* ------------------------------------------------------------------------- *)

(* Decide fm by exhaustive ground instantiation.  The negation of fm is      *)
(* Skolemized; if that introduces or contains any function symbol of         *)
(* nonzero arity the procedure fails with Failure "Not decidable".           *)
(* Otherwise all tuples of constants (a constant c is added when there are   *)
(* none) are substituted for the free variables of the clauses, and the      *)
(* result is true exactly when dpll reports the union of all those ground    *)
(* clause sets to be unsatisfiable.                                          *)
let aedecide fm =
  let sfm = skolemize(Not fm) in
  let fvs = fv sfm in
  let is_const (_,arity) = (arity = 0) in
  let fns = functions sfm in
  let fns' = if exists is_const fns then fns else insert ("c",0) fns in
  let consts,funcs = partition is_const fns' in
  if funcs <> [] then failwith "Not decidable" else
  let cntms = smap (fun (name,_) -> Fn(name,[])) consts in
  let tuples = groundtuples cntms [] 0 (length fvs) in
  let clauses = clausal sfm in
  let ground_instance tup =
    let inst = instantiate fvs tup in
    smap (smap (formsubst inst)) clauses in
  not(dpll(unions (map ground_instance tuples)));;

(* ------------------------------------------------------------------------- *)
(* In this case it's quicker.                                                *)
(* ------------------------------------------------------------------------- *)

(*
aedecide los;;
*)

(* ------------------------------------------------------------------------- *)
(* A nicer alternative is to modify the Herbrand loop.                       *)
(* ------------------------------------------------------------------------- *)

(* Herbrand loop that can give up.  Each pending tuple is used to            *)
(* instantiate fl0 and fold it into the accumulated data fl via mfn; the     *)
(* loop returns true as soon as tfn rejects the accumulated data.  When the  *)
(* pending tuples are exhausted, the next batch is requested from            *)
(* groundtuples at level n, and the result is false if that batch is empty.  *)
(* A progress line is printed on every call.                                 *)
let rec herbloop mfn tfn fl0 cntms funcs fvs n fl tried tuples =
  let report =
    string_of_int(length tried) ^ " ground instances tried; " ^
    string_of_int(length fl) ^ " items in list" in
  print_string report;
  print_newline();
  match tuples with
    tup::tups ->
          let fl' = mfn fl0 (formsubst(instantiate fvs tup)) fl in
          if tfn fl'
          then herbloop mfn tfn fl0 cntms funcs fvs n fl' (tup::tried) tups
          else true
  | [] ->
          (match groundtuples cntms funcs n (length fvs) with
             [] -> false
           | newtups ->
               herbloop mfn tfn fl0 cntms funcs fvs (n + 1) fl tried newtups);;

(* ------------------------------------------------------------------------- *)
(* Show how we need to do PNF transformation with care.                      *)
(* ------------------------------------------------------------------------- *)

(*
let fm = <<(forall x. p(x)) \/ (exists y. p(y))>>;;

pnf fm;;

nnf(Not(pnf(nnf(simplify los))));;

pnf(nnf(simplify(Not los)));;

(* ------------------------------------------------------------------------- *)
(* Also the group theory problem.                                            *)
(* ------------------------------------------------------------------------- *)

aedecide
 <<(forall x. P(1,x,x)) /\
   (forall x. P(x,x,1)) /\
   (forall u v w x y z. P(x,y,u) /\ P(y,z,w) ==> (P(x,w,v) <=> P(u,z,v)))
   ==> forall a b c. P(a,b,c) ==> P(b,a,c)>>;;

aedecide
 <<(forall x. P(x,x,1)) /\
   (forall u v w x y z. P(x,y,u) /\ P(y,z,w) ==> (P(x,w,v) <=> P(u,z,v)))
   ==> forall a b c. P(a,b,c) ==> P(b,a,c)>>;;

(* ------------------------------------------------------------------------- *)
(* A bigger example.                                                         *)
(* ------------------------------------------------------------------------- *)

let p29 =
 <<(exists x. P(x)) /\ (exists x. G(x)) ==>
   ((forall x. P(x) ==> H(x)) /\ (forall x. G(x) ==> J(x)) <=>
    (forall x y. P(x) /\ G(y) ==> H(x) /\ J(y)))>>;;

aedecide p29;;

davisputnam p29;;

(* ------------------------------------------------------------------------- *)
(* The following, however, doesn't work with aedecide.                       *)
(* ------------------------------------------------------------------------- *)

let p18 = <<exists y. forall x. P(y) ==> P(x)>>;;

(*** aedecide p18;; ***)

davisputnam p18;;
*)

(* ------------------------------------------------------------------------- *)
(* Simple-minded miniscoping procedure.                                      *)
(* ------------------------------------------------------------------------- *)

(* Existentially quantify x over a conjunction given as a list, restricting  *)
(* the quantifier to those conjuncts in which x occurs free.                 *)
let separate x cjs =
  let mentions_x cj = mem x (fv cj) in
  match partition mentions_x cjs with
    ([],without) -> list_conj without
  | (withx,[]) -> Exists(x,list_conj withx)
  | (withx,without) ->
        And(Exists(x,list_conj withx),list_conj without);;

(* Push an existential quantifier over x into p: if x is not free in p the   *)
(* formula is returned as is; otherwise p is put in DNF (via purednf of its  *)
(* NNF) and the quantifier is distributed over the disjuncts with separate.  *)
let pushquant x p =
  if mem x (fv p) then
    list_disj (map (separate x) (purednf(nnf p)))
  else p;;

(* Miniscoping: traverse negations, conjunctions and disjunctions, pushing   *)
(* each quantifier inwards with pushquant.  A universal quantifier is        *)
(* handled as a negated existential over the negated body.  Any other        *)
(* formula is returned unchanged.                                            *)
let rec miniscope fm =
  match fm with
    Exists(x,body) -> pushquant x (miniscope body)
  | Forall(x,body) -> Not(pushquant x (Not(miniscope body)))
  | Not p -> Not(miniscope p)
  | And(p,q) -> And(miniscope p,miniscope q)
  | Or(p,q) -> Or(miniscope p,miniscope q)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
miniscope(nnf <<exists y. forall x. P(y) ==> P(x)>>);;

let fm = miniscope(nnf
 <<(forall x y. exists z. forall w. P(x) /\ Q(y) ==> R(z) /\ U(w))
   ==> (exists x y. P(x) /\ Q(y)) ==> (exists z. R(z))>>);;

pnf(nnf fm);;
*)

(* ------------------------------------------------------------------------- *)
(* Stronger version of "aedecide" similar to Wang's classic procedure.       *)
(* ------------------------------------------------------------------------- *)

(* Simplify, convert to NNF and miniscope the formula before handing it to   *)
(* aedecide.                                                                 *)
let wang fm = aedecide (miniscope (nnf (simplify fm)));;

(* ------------------------------------------------------------------------- *)
(* It works well on simple monadic formulas.                                 *)
(* ------------------------------------------------------------------------- *)

(*
let p18 = wang
 <<exists y. forall x. P(y) ==> P(x)>>;;

let p19 = wang
 <<exists x. forall y z. (P(y) ==> Q(z)) ==> P(x) ==> Q(x)>>;;

let p20 = wang
 <<(forall x y. exists z. forall w. P(x) /\ Q(y) ==> R(z) /\ U(w))
   ==> (exists x y. P(x) /\ Q(y)) ==> (exists z. R(z))>>;;

let p21 = wang
 <<(exists x. P ==> Q(x)) /\ (exists x. Q(x) ==> P)
   ==> (exists x. P <=> Q(x))>>;;

let p22 = wang
 <<(forall x. P <=> Q(x)) ==> (P <=> (forall x. Q(x)))>>;;

(* ------------------------------------------------------------------------- *)
(* But not on this one!                                                      *)
(* ------------------------------------------------------------------------- *)

let p34 =
 <<((exists x. forall y. P(x) <=> P(y)) <=>
    ((exists x. Q(x)) <=> (forall y. Q(y)))) <=>
   ((exists x. forall y. Q(x) <=> Q(y)) <=>
    ((exists x. P(x)) <=> (forall y. P(y))))>>;;

pnf(nnf(miniscope(nnf p34)));;
*)

(* ------------------------------------------------------------------------- *)
(* Checking classic Aristotelean syllogisms.                                 *)
(* ------------------------------------------------------------------------- *)

(* The four forms of premiss used in syllogisms, each relating a subject     *)
(* term S to a predicate term P.                                             *)
type sylltype =
   Syll_A     (* All S are P      *)
 | Syll_E     (* No S are P       *)
 | Syll_I     (* Some S are P     *)
 | Syll_O;;   (* Some S are not P *)

(* Build the first-order formula for a premiss of type ty whose subject and  *)
(* predicate are the unary predicate names s and p, using the variable x.    *)
let syllprem ty (s,p) =
  let atom name = Atom(R(name,[Var "x"])) in
  let subj = atom s in
  let pred = atom p in
  match ty with
    Syll_A -> Forall("x",Imp(subj,pred))
  | Syll_E -> Forall("x",Imp(subj,Not(pred)))
  | Syll_I -> Exists("x",And(subj,pred))
  | Syll_O -> Exists("x",And(subj,Not(pred)));;

(* Render a premiss formula of one of the four shapes built by syllprem as   *)
(* an English phrase.  Formulas of any other shape are not matched.          *)
let anglicize_prem fm =
  let phrase quantifier s copula p = quantifier ^ s ^ copula ^ p in
  match fm with
    Forall(_,Imp(Atom(R(s,[Var _])),Atom(R(p,[Var _])))) ->
        phrase "all " s " are " p
  | Forall(_,Imp(Atom(R(s,[Var _])),Not(Atom(R(p,[Var _]))))) ->
        phrase "no " s " are " p
  | Exists(_,And(Atom(R(s,[Var _])),Atom(R(p,[Var _])))) ->
        phrase "some " s " are " p
  | Exists(_,And(Atom(R(s,[Var _])),Not(Atom(R(p,[Var _]))))) ->
        phrase "some " s " are not " p;;

(* Render a syllogism of the form (t1 /\ t2) ==> t3 as an English sentence,  *)
(* using anglicize_prem for each of the three parts.  Formulas of any other  *)
(* shape are not matched.                                                    *)
let anglicize_syllogism syll =
  match syll with
    Imp(And(t1,t2),t3) ->
      let conclusion = anglicize_prem t3 in
      let second = anglicize_prem t2 in
      let first = anglicize_prem t1 in
      "If " ^ first ^ " and " ^ second ^ ", then " ^ conclusion;;

(* Every syllogism obtainable by choosing a premiss type for each of the     *)
(* three parts: first premiss over M and P (either order), second premiss    *)
(* over S and M (either order), and conclusion over S and P.                 *)
let all_possible_syllogisms =
  let premisses terms =
    allpairs syllprem [Syll_A; Syll_E; Syll_I; Syll_O] terms in
  let firsts = premisses ["M","P"; "P","M"] in
  let seconds = premisses ["S","M"; "M","S"] in
  let conclusions = premisses ["S","P"] in
  let antecedents = allpairs (fun p1 p2 -> And(p1,p2)) firsts seconds in
  allpairs (fun ante concl -> Imp(ante,concl)) antecedents conclusions;;

(* The syllogisms that aedecide accepts; computed when this definition is    *)
(* evaluated.                                                                *)
let all_valid_syllogisms =
  filter (fun syll -> aedecide syll) all_possible_syllogisms;;

(*
length all_valid_syllogisms;;

map anglicize_syllogism all_valid_syllogisms;;
*)

(* The formulas p ==> s, for s ranging over all possible syllogisms, that    *)
(* aedecide accepts; p is an extra assumption placed in front of each one.   *)
let all_cond_valid_syllogisms p =
  let conditioned = map (fun syll -> Imp(p,syll)) all_possible_syllogisms in
  filter aedecide conditioned;;

(*
length(all_cond_valid_syllogisms <<exists x. S(x)>>);;

length(all_cond_valid_syllogisms
 <<(exists x. P(x)) /\ (exists x. M(x))>>);;

length(all_cond_valid_syllogisms
 <<(exists x. S(x)) /\ (exists x. M(x)) /\ (exists x. P(x))>>);;

length(all_cond_valid_syllogisms
 <<((forall x. P(x) ==> M(x)) ==> (exists x. P(x) /\ M(x))) /\
   ((forall x. P(x) ==> ~M(x)) ==> (exists x. P(x) /\ ~M(x))) /\
   ((forall x. M(x) ==> P(x)) ==> (exists x. M(x) /\ P(x))) /\
   ((forall x. M(x) ==> ~P(x)) ==> (exists x. M(x) /\ ~P(x))) /\
   ((forall x. S(x) ==> M(x)) ==> (exists x. S(x) /\ M(x))) /\
   ((forall x. S(x) ==> ~M(x)) ==> (exists x. S(x) /\ ~M(x))) /\
   ((forall x. M(x) ==> S(x)) ==> (exists x. M(x) /\ S(x))) /\
   ((forall x. M(x) ==> ~S(x)) ==> (exists x. M(x) /\ ~S(x)))>>);;
*)

(* ------------------------------------------------------------------------- *)
(* Decide a formula on all models of size n.                                 *)
(* ------------------------------------------------------------------------- *)

(* All lists of length n whose elements are drawn from l; for n = 0 this is  *)
(* the single empty list.                                                    *)
let rec alltuples n l =
  match n with
    0 -> [[]]
  | _ ->
      let shorter = alltuples (n - 1) l in
      allpairs (fun hd tl -> hd::tl) l shorter;;

(* All mappings obtained by starting from undef and, for each point of dom,  *)
(* overriding it (via valmod) with each possible value from ran.             *)
let allmappings dom ran =
  let extend pt maps = allpairs (valmod pt) ran maps in
  itlist extend dom [undef];;

(* Like allmappings, but dom is a list of pairs (p,n) and the candidate      *)
(* values for the point p are given by ran n.                                *)
let alldepmappings dom ran =
  let extend (pt,n) maps = allpairs (valmod pt) (ran n) maps in
  itlist extend dom [undef];;

(* All mappings from n-tuples over dmn to elements of dmn.                   *)
let allfunctions dmn n =
  let arguments = alltuples n dmn in
  allmappings arguments dmn;;

(* All mappings from n-tuples over dmn to truth values.                      *)
let allpredicates dmn n =
  let arguments = alltuples n dmn in
  allmappings arguments [false;true];;

(* Check whether the generalization of fm holds in every interpretation      *)
(* whose domain is 1 -- n, enumerating all interpretations of the function   *)
(* and predicate symbols that occur in fm.                                   *)
let decide_finite n fm =
  let dmn = 1 -- n in
  let predinterps = alldepmappings (predicates fm) (allpredicates dmn) in
  let finterps = alldepmappings (functions fm) (allfunctions dmn) in
  let models =
    allpairs (fun fi pi -> Interp(dmn,fi,pi)) finterps predinterps in
  let closed = generalize fm in
  forall (fun md -> holds md undefined closed) models;;

(* ------------------------------------------------------------------------- *)
(* Decision procedure in principle for formulas with finite model property.  *)
(* ------------------------------------------------------------------------- *)

(* A single MESON attempt with the fixed limit n (no iterative deepening):   *)
(* the clauses of fm are turned into rules via contrapositives and expand    *)
(* is run once from the goal False.                                          *)
let limitedmeson n fm =
  let clauses = clausal(specialize(pnf fm)) in
  let rules =
    itlist (fun cl acc -> contrapositives cl @ acc) clauses [] in
  expand rules [] False (fun env -> env) (undefined,n,0);;

(* Refutation wrapper around limitedmeson: negate the generalized formula,   *)
(* Skolemize it, split it with simpdnf and try each disjunct with limit n.   *)
let limmeson n fm =
  let refutand = askolemize(Not(generalize fm)) in
  map (fun dj -> limitedmeson n (list_conj dj)) (simpdnf refutand);;

(* Decide the formula fm by interleaving proof search and countermodel       *)
(* search for n = 1, 2, 3, ...                                               *)
(*                                                                           *)
(*  - if limmeson with limit n succeeds (raises no Failure), return true;    *)
(*  - otherwise, if decide_finite finds that fm does not hold in some        *)
(*    interpretation with an n-element domain, return false;                 *)
(*  - otherwise try again with n + 1.                                        *)
(*                                                                           *)
(* The loop does not terminate if neither a proof nor a finite countermodel  *)
(* is ever found.  The formula is an explicit parameter: without it the      *)
(* definition would be a constant computed once from whatever top-level      *)
(* value named fm happens to be in scope.                                    *)
let decide_fmp fm =
  let rec test n =
    if can (limmeson n) fm then true
    else if decide_finite n fm then test (n + 1)
    else false in
  test 1;;

(* ------------------------------------------------------------------------- *)
(* Semantic decision procedure for the monadic fragment.                     *)
(* ------------------------------------------------------------------------- *)

(* Decide a formula of the monadic fragment by checking it on all            *)
(* interpretations of size 2^k, where k is the number of unary predicates.   *)
(* Fails if fm contains any function symbol or a predicate of arity above 1. *)
let decide_monadic fm =
  let funcs = functions fm in
  let monadic,other = partition (fun (_,ar) -> ar = 1) (predicates fm) in
  let has_wide_predicate = exists (fun (_,ar) -> ar > 1) other in
  if funcs <> [] || has_wide_predicate
  then failwith "Not in the monadic subset" else
  let size = funpow (length monadic) (fun k -> 2 * k) 1 in
  decide_finite size fm;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
decide_monadic p34;;
*)
(* ========================================================================= *)
(* Introduction to quantifier elimination.                                   *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Flatten nested disjunctions into the list of their disjuncts; a formula   *)
(* that is not a disjunction yields a singleton list.                        *)
let rec disjuncts = function
    Or(p,q) -> disjuncts p @ disjuncts q
  | fm -> [fm];;

(* Flatten nested conjunctions into the list of their conjuncts; a formula   *)
(* that is not a conjunction yields a singleton list.                        *)
let rec conjuncts = function
    And(p,q) -> conjuncts p @ conjuncts q
  | fm -> [fm];;

(* ------------------------------------------------------------------------- *)
(* Lift procedure given literal modifier, formula normalizer, and a basic    *)
(* elimination procedure for existential formulas with conjunctive body.     *)
(* ------------------------------------------------------------------------- *)

(* Build a quantifier elimination procedure from three ingredients:          *)
(*                                                                           *)
(*   afn vars atom  - rewrites each atom;                                    *)
(*   nfn fm         - normalizes a quantifier-free body before it is split   *)
(*                    into disjuncts;                                        *)
(*   qfn vars fm    - eliminates the quantifier from an existential formula  *)
(*                    whose body is a conjunction.                           *)
(*                                                                           *)
(* Universal quantifiers are treated as negated existentials.  For an        *)
(* existential, the quantifier is distributed over the disjuncts of the      *)
(* normalized body, and within each disjunct only the conjuncts containing   *)
(* the bound variable are passed to qfn.  The final result is simplified.    *)
let lift_qelim afn nfn qfn =
  let qelim x vars dj =
    let with_x,without_x =
      partition (fun cj -> mem x (fv cj)) (conjuncts dj) in
    match with_x with
      [] -> dj
    | _ ->
        let eliminated = qfn vars (Exists(x,list_conj with_x)) in
        itlist (fun cj acc -> And(cj,acc)) without_x eliminated in
  let rec qelift vars fm =
    let recur = qelift vars in
    match fm with
    | Atom(R(_,_)) -> afn vars fm
    | Not(p) -> Not(recur p)
    | And(p,q) -> And(recur p,recur q)
    | Or(p,q) -> Or(recur p,recur q)
    | Imp(p,q) -> Imp(recur p,recur q)
    | Iff(p,q) -> Iff(recur p,recur q)
    | Forall(x,p) -> Not(recur (Exists(x,Not p)))
    | Exists(x,p) ->
          let body = nfn(qelift (x::vars) p) in
          list_disj(map (qelim x vars) (disjuncts body))
    | _ -> fm in
  fun fm -> simplify(qelift (fv fm) fm);;

(* ------------------------------------------------------------------------- *)
(* Cleverer (propositional) NNF with conditional and literal modification.   *)
(* ------------------------------------------------------------------------- *)

(* NNF conversion parameterized by a literal modifier lfn, which is applied  *)
(* to every subformula not handled by the connective rules.  A negated       *)
(* formula of the conditional shape (p /\ q) \/ (p' /\ r), with p' equal to  *)
(* negate p, is negated branch by branch instead of by the generic rule for  *)
(* a negated disjunction.  The input is simplified before and after.         *)
let cnnf lfn =
  let rec walk fm =
    match fm with
      And(p,q) -> And(walk p,walk q)
    | Or(p,q) -> Or(walk p,walk q)
    | Imp(p,q) -> Or(walk(Not p),walk q)
    | Iff(p,q) -> Or(And(walk p,walk q),And(walk(Not p),walk(Not q)))
    | Not(body) -> walk_negated fm body
    | _ -> lfn fm
  (* fm is the whole negated formula Not(body); it is what lfn receives when *)
  (* body is not one of the connectives handled here.                        *)
  and walk_negated fm body =
    match body with
      Not p -> walk p
    | And(p,q) -> Or(walk(Not p),walk(Not q))
    | Or(And(p,q),And(p',r)) when p' = negate p ->
         Or(walk (And(p,Not q)),walk (And(p',Not r)))
    | Or(p,q) -> And(walk(Not p),walk(Not q))
    | Imp(p,q) -> And(walk p,walk(Not q))
    | Iff(p,q) -> Or(And(walk p,walk(Not q)),
                     And(walk(Not p),walk q))
    | _ -> lfn fm in
  fun fm -> simplify (walk (simplify fm));;

(* ------------------------------------------------------------------------- *)
(* Initial literal simplifier and intermediate literal modifier.             *)
(* ------------------------------------------------------------------------- *)

(* Literal modifier for dense linear orders: a negated strict inequality or  *)
(* a negated equation is rewritten as a disjunction of positive atoms;       *)
(* anything else is returned unchanged.                                      *)
let lfn_dlo fm =
  let less a b = Atom(R("<",[a;b])) in
  match fm with
    Not(Atom(R("<",[s;t]))) -> Or(Atom(R("=",[s;t])),less t s)
  | Not(Atom(R("=",[s;t]))) -> Or(less s t,less t s)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Simple example of dense linear orderings; this is the base function.      *)
(* ------------------------------------------------------------------------- *)

(* Basic elimination step for dense linear orders.  The argument must be an  *)
(* existential formula whose body is a conjunction of atoms; any other       *)
(* formula fails with Failure "dlobasic".                                    *)
(*                                                                           *)
(* After discarding trivial conjuncts x = x:                                 *)
(*  - if some conjunct is an equation, its other side is substituted for x   *)
(*    in the remaining conjuncts;                                            *)
(*  - otherwise, if x < x is present the result is False;                    *)
(*  - otherwise every conjunct is expected to be a strict inequality, and    *)
(*    the result is the conjunction of l < r for every lower bound l < x     *)
(*    and upper bound x < r.                                                 *)
(*                                                                           *)
(* The try is parenthesized on purpose: otherwise the final catch-all arm    *)
(* below would be parsed as an extra exception handler rather than as the    *)
(* second case of the outer match.                                           *)
let dlobasic fm =
  match fm with
    Exists(x,p) ->
      let cjs = subtract (conjuncts p) [Atom(R("=",[Var x;Var x]))] in
      (try let eqn = find
             (function (Atom(R("=",_))) -> true | _ -> false) cjs in
           let (Atom(R("=",[s;t]))) = eqn in
           let y = if s = Var x then t else s in
           list_conj(map (formsubst (x := y)) (subtract cjs [eqn]))
       with Failure _ ->
           (* No equation available: combine lower and upper bounds.         *)
           if mem (Atom(R("<",[Var x;Var x]))) cjs then False else
           let l,r =
             partition (fun (Atom(R("<",[s;t]))) -> t = Var x) cjs in
           let lefts = map (fun (Atom(R("<",[l;_]))) -> l) l
           and rights = map (fun (Atom(R("<",[_;r]))) -> r) r in
           list_conj(allpairs (fun l r -> Atom(R("<",[l;r])))
                              lefts rights))
  | _ -> failwith "dlobasic";;

(* ------------------------------------------------------------------------- *)
(* Overall quelim procedure.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Atom modifier for dense linear orders: express the other order relations  *)
(* in terms of strict less-than; other atoms are returned unchanged.  The    *)
(* vars argument is not used.                                                *)
let afn_dlo vars fm =
  let less a b = Atom(R("<",[a;b])) in
  match fm with
    Atom(R("<=",[s;t])) -> Not(less t s)
  | Atom(R(">=",[s;t])) -> Not(less s t)
  | Atom(R(">",[s;t])) -> less t s
  | _ -> fm;;

(* Quantifier elimination for dense linear orders, assembled by lift_qelim   *)
(* from afn_dlo, a normalizer (cnnf with lfn_dlo followed by dnf), and       *)
(* dlobasic as the core step (which ignores the variable list).              *)
let quelim_dlo =
  let normalize fm = dnf (cnnf lfn_dlo fm) in
  lift_qelim afn_dlo normalize (fun _ -> dlobasic);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
quelim_dlo <<forall x y. exists z. z < x /\ z < y>>;;

quelim_dlo <<exists z. z < x /\ z < y>>;;

quelim_dlo <<exists z. x < z /\ z < y>>;;

quelim_dlo <<(forall x. x < a ==> x < b)>>;;

quelim_dlo <<forall a b. (forall x. x < a <=> x < b) <=> a = b>>;;

time quelim_dlo <<forall x. exists y. x < y>>;;

time quelim_dlo <<forall x y z. x < y /\ y < z ==> x < z>>;;

time quelim_dlo <<forall x y. x < y \/ (x = y) \/ y < x>>;;

time quelim_dlo <<exists x y. x < y /\ y < x>>;;

time quelim_dlo <<forall x y. exists z. z < x /\ x < y>>;;

time quelim_dlo <<exists z. z < x /\ x < y>>;;

time quelim_dlo <<forall x y. exists z. z < x /\ z < y>>;;

time quelim_dlo <<forall x y. x < y ==> exists z. x < z /\ z < y>>;;

time quelim_dlo
  <<forall x y. ~(x = y) ==> exists u. u < x /\ (y < u \/ x < y)>>;;

time quelim_dlo <<exists x. x = x>>;;

time quelim_dlo <<exists x. x = x /\ x = y>>;;

time quelim_dlo <<exists z. x < z /\ z < y>>;;

time quelim_dlo <<exists z. x <= z /\ z <= y>>;;

time quelim_dlo <<exists z. x < z /\ z <= y>>;;

time quelim_dlo <<forall x y z. exists u. u < x /\ u < y /\ u < z>>;;

time quelim_dlo <<forall y. x < y /\ y < z ==> w < z>>;;

time quelim_dlo <<forall x y. x < y>>;;

time quelim_dlo <<exists z. z < x /\ x < y>>;;

time quelim_dlo <<forall a b. (forall x. x < a ==> x < b) <=> a <= b>>;;

time quelim_dlo <<forall x. x < a ==> x < b>>;;

time quelim_dlo <<forall x. x < a ==> x <= b>>;;

time quelim_dlo <<forall a b. exists x. ~(x = a) \/ ~(x = b) \/ (a = b)>>;;

time quelim_dlo <<forall x y. x <= y \/ x > y>>;;

time quelim_dlo <<forall x y. x <= y \/ x < y>>;;
*)
(* ========================================================================= *)
(* Cooper's algorithm for Presburger arithmetic.                             *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Lift operations up to numerals.                                           *)
(* ------------------------------------------------------------------------- *)

(* Represent the number n as a nullary function whose name is its decimal    *)
(* string.                                                                   *)
let mk_numeral n =
  let name = string_of_num n in
  Fn(name,[]);;

(* Recover the number from a numeral term: a nullary function whose name is  *)
(* parsed by num_of_string.  Any other term fails.                           *)
let dest_numeral tm =
  match tm with
    Fn(ns,[]) -> num_of_string ns
  | _ -> failwith "dest_numeral";;

(* True when dest_numeral succeeds on the term (raises no Failure).          *)
let is_numeral tm = can dest_numeral tm;;

(* Lift a unary operation on numbers to numeral terms.                       *)
let numeral1 fn n =
  let value = dest_numeral n in
  mk_numeral (fn value);;

(* Lift a binary operation on numbers to numeral terms.                      *)
let numeral2 fn m n =
  let right = dest_numeral n in
  let left = dest_numeral m in
  mk_numeral (fn left right);;

(* ------------------------------------------------------------------------- *)
(* Operations on canonical linear terms c1 * x1 + ... + cn * xn + k          *)
(*                                                                           *)
(* Note that we're quite strict: the ci must be present even if 1            *)
(* (but if 0 we expect the monomial to be omitted) and k must be there       *)
(* even if it's zero. Thus, it's a constant iff not an addition term.        *)
(* ------------------------------------------------------------------------- *)

(* Multiply a canonical linear term by the number n.  Multiplying by zero    *)
(* gives the numeral 0 directly; otherwise every coefficient and the final   *)
(* constant are scaled by n.                                                 *)
let rec linear_cmul n tm =
  if n =/ Int 0 then Fn("0",[]) else
  let scale = numeral1 (fun c -> n */ c) in
  match tm with
    Fn("+",[Fn("*",[coeff; v]); rest]) ->
        Fn("+",[Fn("*",[scale coeff; v]); linear_cmul n rest])
  | const -> scale const;;

(* Compare two variable terms by the position of their names in vars, as     *)
(* decided by earlier.  Only variable terms are matched.                     *)
let earlierv vars tm1 tm2 =
  match (tm1,tm2) with
    (Var x,Var y) -> earlier vars x y;;

(* Add two canonical linear terms, merging their monomials in the variable   *)
(* order given by vars.  Monomials on the same variable are combined, and    *)
(* dropped when the coefficients cancel; the two constants are added at the  *)
(* end.                                                                      *)
let rec linear_add vars tm1 tm2 =
  let monomial c x rest = Fn("+",[Fn("*",[c; x]); rest]) in
  match (tm1,tm2) with
   (Fn("+",[Fn("*",[c1; x1]); rest1]),
    Fn("+",[Fn("*",[c2; x2]); rest2])) ->
        if x1 = x2 then
          let c = numeral2 (+/) c1 c2 in
          let tail = linear_add vars rest1 rest2 in
          if c = Fn("0",[]) then tail else monomial c x1 tail
        else if earlierv vars x1 x2 then
          monomial c1 x1 (linear_add vars rest1 tm2)
        else
          monomial c2 x2 (linear_add vars tm1 rest2)
  | (Fn("+",[Fn("*",[c1; x1]); rest1]),_) ->
        monomial c1 x1 (linear_add vars rest1 tm2)
  | (_,Fn("+",[Fn("*",[c2; x2]); rest2])) ->
        monomial c2 x2 (linear_add vars tm1 rest2)
  | _ -> numeral2 (+/) tm1 tm2;;

(* Negate a canonical linear term by multiplying it by -1.                   *)
let linear_neg tm =
  let minus_one = Int(-1) in
  linear_cmul minus_one tm;;

(* Subtract canonical linear terms: add tm1 to the negation of tm2.          *)
let linear_sub vars tm1 tm2 =
  let negated = linear_neg tm2 in
  linear_add vars tm1 negated;;

(* ------------------------------------------------------------------------- *)
(* Linearize a term.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Convert a term built from variables, numerals, negation, addition,        *)
(* subtraction and multiplication into canonical linear form.  A product is  *)
(* accepted only when at least one linearized factor is a numeral; other     *)
(* products and unknown terms fail.                                          *)
let rec lint vars tm =
  let go = lint vars in
  match tm with
    Var _ -> Fn("+",[Fn("*",[Fn("1",[]); tm]); Fn("0",[])])
  | Fn("-",[t]) -> linear_neg (go t)
  | Fn("+",[s;t]) -> linear_add vars (go s) (go t)
  | Fn("-",[s;t]) -> linear_sub vars (go s) (go t)
  | Fn("*",[s;t]) ->
        let s' = go s and t' = go t in
        if is_numeral s' then linear_cmul (dest_numeral s') t'
        else if is_numeral t' then linear_cmul (dest_numeral t') s'
        else failwith "lint: apparent nonlinearity"
  | _ ->
        if not (is_numeral tm) then failwith "lint: unknown term"
        else tm;;

(* ------------------------------------------------------------------------- *)
(* Linearize the atoms in a formula, and eliminate non-strict inequalities.  *)
(* ------------------------------------------------------------------------- *)

(* Build the atom relating the numeral 0 to the linearized term t under the  *)
(* predicate named p.                                                        *)
let mkatom vars p t =
  let linear = lint vars t in
  Atom(R(p,[Fn("0",[]);linear]));;

(* Linearize an atom.  Divisibility atoms get a non-negative divisor and a   *)
(* linearized term.  Equations and inequalities are rewritten to compare 0   *)
(* with a linearized difference, and the non-strict and reversed relations   *)
(* are expressed using strict less-than, adding 1 where needed.  Other       *)
(* formulas are returned unchanged.                                          *)
let linform vars fm =
  let one = Fn("1",[]) in
  let minus a b = Fn("-",[a;b]) in
  let plus_one a = Fn("+",[a;one]) in
  let positive t = mkatom vars "<" t in
  match fm with
    Atom(R("divides",[c;t])) ->
        let c' = numeral1 abs_num c in
        Atom(R("divides",[c';lint vars t]))
  | Atom(R("=",[s;t])) -> mkatom vars "=" (minus t s)
  | Atom(R("<",[s;t])) -> positive (minus t s)
  | Atom(R(">",[s;t])) -> positive (minus s t)
  | Atom(R("<=",[s;t])) -> positive (minus (plus_one t) s)
  | Atom(R(">=",[s;t])) -> positive (minus (plus_one s) t)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Post-NNF transformation eliminating negated inequalities.                 *)
(* ------------------------------------------------------------------------- *)

(* Replace a negated inequality ~(0 < t) by the positive 0 < 1 - t; other    *)
(* formulas are returned unchanged.                                          *)
let posineq fm =
  match fm with
    Not(Atom(R("<",[Fn("0",[]); t]))) ->
        let flipped = linear_sub [] (Fn("1",[])) t in
        Atom(R("<",[Fn("0",[]); flipped]))
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Find the LCM of the coefficients of x.                                    *)
(* ------------------------------------------------------------------------- *)

(* [formlcm x fm] is the least common multiple of the absolute values of the *)
(* coefficients c taken from every atom of fm whose second argument has the  *)
(* shape  c * x + z.  Atoms of any other shape contribute 1; negations,      *)
(* conjunctions and disjunctions are traversed.                              *)
let rec formlcm x fm =
  match fm with
  | Atom(R(_,[_;Fn("+",[Fn("*",[c;y]);_])])) when y = x ->
      abs_num(dest_numeral c)
  | Not(p) -> formlcm x p
  | And(p,q) | Or(p,q) -> lcm_num (formlcm x p) (formlcm x q)
  | _ -> Int 1;;

(* ------------------------------------------------------------------------- *)
(* Adjust all coefficients of x in formula; fold in reduction to +/- 1.      *)
(* ------------------------------------------------------------------------- *)

(* [adjustcoeff x l fm] rescales every atom  rel(lhs, c * x + rest)  of fm.  *)
(* With factor = l / c:                                                      *)
(*   - for the relation "<" the scale is |factor|, so the new coefficient of *)
(*     x is factor / |factor|, i.e. plus or minus one;                       *)
(*   - for any other relation the scale is factor itself, giving 1;          *)
(*   - lhs is multiplied by |factor| and rest by the scale.                  *)
(* Not, And and Or are traversed; other formulas are left alone.             *)
let rec adjustcoeff x l fm =
  let go = adjustcoeff x l in
  match fm with
  | Atom(R(rel,[lhs; Fn("+",[Fn("*",[c;y]);rest])])) when y = x ->
      let factor = l // dest_numeral c in
      let scale = if rel = "<" then abs_num(factor) else factor in
      let unit_x = Fn("*",[mk_numeral(factor // scale); x]) in
      Atom(R(rel,[linear_cmul (abs_num factor) lhs;
                  Fn("+",[unit_x; linear_cmul scale rest])]))
  | Not(p) -> Not(go p)
  | And(p,q) -> And(go p,go q)
  | Or(p,q) -> Or(go p,go q)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Hence make coefficient of x one in existential formula.                   *)
(* ------------------------------------------------------------------------- *)

(* [unitycoeff x fm] normalizes the coefficients of x in fm to +/- 1.        *)
(* l is the LCM of the coefficients of x (formlcm) and fm1 is fm with every  *)
(* such atom rescaled by adjustcoeff. When l is 1 that is the whole answer;  *)
(* otherwise the result is additionally conjoined with the atom              *)
(* divides(l, 1 * x + 0).                                                    *)
(* The rescaled formula is computed once and reused in both outcomes.        *)
let unitycoeff x fm =
  let l = formlcm x fm in
  let fm' = adjustcoeff x l fm in
  if l =/ Int 1 then fm' else
  let xp = Fn("+",[Fn("*",[Fn("1",[]);x]); Fn("0",[])]) in
  And(Atom(R("divides",[mk_numeral l; xp])),fm');;

(* ------------------------------------------------------------------------- *)
(* The "minus infinity" version.                                             *)
(* ------------------------------------------------------------------------- *)

(* [minusinf x fm] replaces atoms in x by constant truth values:             *)
(*   0 = 1 * x + z   becomes False;                                          *)
(*   0 < c * x + z   becomes False when c is the numeral 1, True otherwise.  *)
(* Not, And and Or are traversed and every other formula is kept.            *)
let rec minusinf x fm =
  match fm with
  | Atom(R("=",[Fn("0",[]); Fn("+",[Fn("*",[Fn("1",[]);y]);_])]))
    when y = x -> False
  | Atom(R("<",[Fn("0",[]); Fn("+",[Fn("*",[coeff;y]);_])])) when y = x ->
      (match coeff with
       | Fn("1",[]) -> False
       | _ -> True)
  | Not(p) -> Not(minusinf x p)
  | And(p,q) -> And(minusinf x p,minusinf x q)
  | Or(p,q) -> Or(minusinf x p,minusinf x q)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* The LCM of all the divisors that involve x.                               *)
(* ------------------------------------------------------------------------- *)

(* [divlcm x fm] is the least common multiple of the divisors d appearing in *)
(* atoms  divides(d, c * x + z)  of fm. Atoms of any other shape contribute  *)
(* 1; negations, conjunctions and disjunctions are traversed.                *)
let rec divlcm x fm =
  match fm with
  | Atom(R("divides",[d;Fn("+",[Fn("*",[_;y]);_])])) when y = x ->
      dest_numeral d
  | Not(p) -> divlcm x p
  | And(p,q) | Or(p,q) -> lcm_num (divlcm x p) (divlcm x q)
  | _ -> Int 1;;

(* ------------------------------------------------------------------------- *)
(* Construct the B-set.                                                      *)
(* ------------------------------------------------------------------------- *)

(* [bset x fm] collects boundary terms from the atoms of fm in x:            *)
(*   ~(0 = 1 * x + a)  and  0 < 1 * x + a   each contribute  -a;             *)
(*   0 = 1 * x + a                          contributes  -(a + 1).           *)
(* Results of subformulas under Not, And and Or are merged with union; any   *)
(* other formula contributes nothing.                                        *)
let rec bset x fm =
  match fm with
  | Not(Atom(R("=",[Fn("0",[]); Fn("+",[Fn("*",[Fn("1",[]);y]);a])])))
  | Atom(R("<",[Fn("0",[]); Fn("+",[Fn("*",[Fn("1",[]);y]);a])]))
    when y = x -> [linear_neg a]
  | Atom(R("=",[Fn("0",[]); Fn("+",[Fn("*",[Fn("1",[]);y]);a])]))
    when y = x ->
      let a_plus_1 = linear_add [] a (Fn("1",[])) in
      [linear_neg a_plus_1]
  | Not(p) -> bset x p
  | And(p,q) | Or(p,q) -> union (bset x p) (bset x q)
  | _ -> [];;

(* ------------------------------------------------------------------------- *)
(* Replace top variable with another linear form, retaining canonicality.    *)
(* ------------------------------------------------------------------------- *)

(* [linrep vars x t fm] substitutes the linear term t for x: each atom       *)
(* rel(lhs, c * x + rest)  becomes  rel(lhs, c * t + rest), the product      *)
(* being formed with linear_cmul and the sum with linear_add over vars.      *)
(* Not, And and Or are traversed; other formulas are unchanged.              *)
let rec linrep vars x t fm =
  let go = linrep vars x t in
  match fm with
  | Atom(R(rel,[lhs; Fn("+",[Fn("*",[c;y]);rest])])) when y = x ->
      let scaled = linear_cmul (dest_numeral c) t in
      Atom(R(rel,[lhs; linear_add vars scaled rest]))
  | Not(p) -> Not(go p)
  | And(p,q) -> And(go p,go q)
  | Or(p,q) -> Or(go p,go q)
  | _ -> fm;;

(* ------------------------------------------------------------------------- *)
(* Evaluation of constant expressions.                                       *)
(* ------------------------------------------------------------------------- *)

(* Association list from relation names to their meaning on numerals.        *)
(* For "divides" the first argument is the divisor: divides d n holds when   *)
(* n mod d is zero.                                                          *)
let operations =
  let divides d n = mod_num n d =/ Int 0 in
  [ "=",       (=/);
    "<",       (</);
    ">",       (>/);
    "<=",      (<=/);
    ">=",      (>=/);
    "divides", divides ];;

(* Evaluate a binary atom whose relation is listed in operations and whose   *)
(* two arguments are both numerals, yielding True or False. If anything in   *)
(* that attempt raises Failure, or the atom is not binary, the atom is       *)
(* returned unevaluated.                                                     *)
let evalc_atom at =
  match at with
  | R(p,[s;t]) ->
      (try
         let holds = assoc p operations (dest_numeral s) (dest_numeral t) in
         if holds then True else False
       with Failure _ -> Atom at)
  | _ -> Atom at;;

(* Apply evalc_atom to every atom of a formula via onatoms.                  *)
let evalc fm = onatoms evalc_atom fm;;

(* ------------------------------------------------------------------------- *)
(* Hence the core quantifier elimination procedure.                          *)
(* ------------------------------------------------------------------------- *)

(* Eliminate the quantifier from  exists x0. p0.                             *)
(* The body is first normalized by unitycoeff. For each j in the range       *)
(* 1 --- divlcm the result contains the disjunction of                       *)
(*   - the simplified minusinf formula with x replaced by j, and             *)
(*   - the body with x replaced by b + j, for every b in the bset.           *)
(* The per-j disjunctions are themselves joined with list_disj.              *)
(* Fails with Failure if fm is not an existential formula.                   *)
let cooper vars fm =
  match fm with
  | Exists(x0,p0) ->
      let x = Var x0 in
      let body = unitycoeff x p0 in
      let at_minus_inf = simplify(minusinf x body) and boundary = bset x body
      and offsets = Int 1 --- divlcm x body in
      let stage j =
        let from_bset =
          map (fun b -> linrep vars x (linear_add vars b (mk_numeral j)) body)
              boundary in
        let from_inf = linrep vars x (mk_numeral j) at_minus_inf in
        list_disj (from_inf :: from_bset) in
      list_disj (map stage offsets)
  | _ -> failwith "cooper: not an existential formula";;

(* ------------------------------------------------------------------------- *)
(* Overall function.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Top-level procedure: lift_qelim is instantiated with linform as the atom  *)
(* normalizer, evalc followed by cnnf posineq as the formula normalizer and  *)
(* cooper as the core step; its result is passed through evalc and then      *)
(* simplify.                                                                 *)
let integer_qelim =
  let normalize = cnnf posineq ** evalc in
  let eliminate = lift_qelim linform normalize cooper in
  simplify ** evalc ** eliminate;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
integer_qelim <<forall x y. x < y ==> 2 * x + 1 < 2 * y>>;;

integer_qelim <<forall x y. ~(2 * x + 1 = 2 * y)>>;;

integer_qelim <<exists x y. x > 0 /\ y >= 0 /\ 3 * x - 5 * y = 1>>;;

integer_qelim <<exists x y z. 4 * x - 6 * y = 1>>;;

integer_qelim <<forall x. b < x ==> a <= x>>;;

integer_qelim <<forall x. a < 3 * x ==> b < 3 * x>>;;

time integer_qelim <<forall x y. x <= y ==> 2 * x + 1 < 2 * y>>;;

time integer_qelim <<(exists d. y = 65 * d) ==> (exists d. y = 5 * d)>>;;

time integer_qelim
  <<forall y. (exists d. y = 65 * d) ==> (exists d. y = 5 * d)>>;;

time integer_qelim <<forall x y. ~(2 * x + 1 = 2 * y)>>;;

time integer_qelim
  <<forall x y z. (2 * x + 1 = 2 * y) ==> x + y + z > 129>>;;

time integer_qelim <<forall x. a < x ==> b < x>>;;

time integer_qelim <<forall x. a <= x ==> b < x>>;;

(* ------------------------------------------------------------------------- *)
(* Formula examples from Cooper's paper.                                     *)
(* ------------------------------------------------------------------------- *)

time integer_qelim <<forall a b. exists x. a < 20 * x /\ 20 * x < b>>;;

time integer_qelim <<exists x. a < 20 * x /\ 20 * x < b>>;;

time integer_qelim <<forall b. exists x. a < 20 * x /\ 20 * x < b>>;;

time integer_qelim
  <<forall a. exists b. a < 4 * b + 3 * a \/ (~(a < b) /\ a > b + 1)>>;;

time integer_qelim
  <<exists y. forall x. x + 5 * y > 1 /\ 13 * x - y > 1 /\ x + 2 < 0>>;;

(* ------------------------------------------------------------------------- *)
(* Some more.                                                                *)
(* ------------------------------------------------------------------------- *)

time integer_qelim <<forall x y. x >= 0 /\ y >= 0
                  ==> 12 * x - 8 * y < 0 \/ 12 * x - 8 * y > 2>>;;

time integer_qelim <<exists x y. 5 * x + 3 * y = 1>>;;

time integer_qelim <<exists x y. 5 * x + 10 * y = 1>>;;

time integer_qelim <<exists x y. x >= 0 /\ y >= 0 /\ 5 * x - 6 * y = 1>>;;


time integer_qelim <<exists w x y z. 2 * w + 3 * x + 4 * y + 5 * z = 1>>;;

time integer_qelim <<exists x y. x >= 0 /\ y >= 0 /\ 5 * x - 3 * y = 1>>;;

time integer_qelim <<exists x y. x >= 0 /\ y >= 0 /\ 3 * x - 5 * y = 1>>;;

time integer_qelim <<exists x y. x >= 0 /\ y >= 0 /\ 6 * x - 3 * y = 1>>;;

time integer_qelim
  <<forall x y. ~(x = 0) ==> 5 * y < 6 * x \/ 5 * y > 6 * x>>;;

time integer_qelim
  <<forall x y. ~divides(5,x) /\ ~divides(6,y) ==> ~(6 * x = 5 * y)>>;;

time integer_qelim <<forall x y. ~divides(5,x) ==> ~(6 * x = 5 * y)>>;;

time integer_qelim <<forall x y. ~(6 * x = 5 * y)>>;;

time integer_qelim <<forall x y. 6 * x = 5 * y ==> exists d. y = 3 * d>>;;

time integer_qelim <<6 * x = 5 * y ==> exists d. y = 3 * d>>;;

(* ------------------------------------------------------------------------- *)
(* Positive variant of the Bezout theorem.                                   *)
(* ------------------------------------------------------------------------- *)

time integer_qelim
  <<forall z. z > 7 ==> exists x y. x >= 0 /\ y >= 0 /\ 3 * x + 5 * y = z>>;;

time integer_qelim
  <<forall z. z > 2 ==> exists x y. x >= 0 /\ y >= 0 /\ 3 * x + 5 * y = z>>;;

time integer_qelim
  <<forall z.
        z <= 7
        ==> ((exists x y. x >= 0 /\ y >= 0 /\ 3 * x + 5 * y = z) <=>
             ~(exists x y. x >= 0 /\ y >= 0 /\ 3 * x + 5 * y = 7 - z))>>;;

(* ------------------------------------------------------------------------- *)
(* Basic result about congruences.                                           *)
(* ------------------------------------------------------------------------- *)

time integer_qelim
  <<forall x. ~divides(2,x) /\ divides(3,x-1) <=>
              divides(12,x-1) \/ divides(12,x-7)>>;;

time integer_qelim
  <<forall x. ~(exists m. x = 2 * m) /\ (exists m. x = 3 * m + 1) <=>
              (exists m. x = 12 * m + 1) \/ (exists m. x = 12 * m + 7)>>;;

(* ------------------------------------------------------------------------- *)
(* Something else.                                                           *)
(* ------------------------------------------------------------------------- *)

time integer_qelim
 <<forall x.
        ~(divides(2,x))
        ==> divides(4,x-1) \/
            divides(8,x-1) \/
            divides(8,x-3) \/
            divides(6,x-1) \/
            divides(14,x-1) \/
            divides(14,x-9) \/
            divides(14,x-11) \/
            divides(24,x-5) \/
            divides(24,x-11)>>;;

(* ------------------------------------------------------------------------- *)
(* Testing fix for an earlier version.                                       *)
(* ------------------------------------------------------------------------- *)

(integer_qelim ** generalize)
  <<a + 2 = b /\ v_3 = b - a + 1 /\ v_2 = b - 2 /\ v_1 = 3 ==> false>>;;

(* ------------------------------------------------------------------------- *)
(* Inspired by the Collatz conjecture.                                       *)
(* ------------------------------------------------------------------------- *)

integer_qelim
  <<exists a b. ~(a = 1) /\ ((2 * b = a) \/ (2 * b = 3 * a + 1)) /\ (a = b)>>;;

integer_qelim
 <<exists a b. a > 1 /\ b > 1 /\
               ((2 * b = a) \/ (2 * b = 3 * a + 1)) /\
               (a = b)>>;;

*)
(* ========================================================================= *)
(* Complex quantifier elimination (by simple divisibility a la Tarski).      *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Basic arithmetic operations on canonical polynomials.                     *)
(* ------------------------------------------------------------------------- *)

(* Addition of polynomials in the nested form  c + x * p.                    *)
(* When both arguments have that form, the variable that comes earlier in    *)
(* vars decides which one is treated as a constant and added into the other  *)
(* by poly_cadd; for the same variable the parts are added pairwise and a    *)
(* zero higher part is dropped. If only one argument has the form, the other *)
(* is added into its constant part. Otherwise both are handed to numeral2    *)
(* for numeral addition.                                                     *)
let rec poly_add vars pol1 pol2 =
  match (pol1,pol2) with
  | (Fn("+",[c; Fn("*",[Var x; p])]),Fn("+",[d; Fn("*",[Var y; q])])) ->
      if earlier vars x y then poly_cadd vars pol2 pol1
      else if earlier vars y x then poly_cadd vars pol1 pol2
      else begin
        let cd = poly_add vars c d and pq = poly_add vars p q in
        match pq with
        | Fn("0",[]) -> cd
        | _ -> Fn("+",[cd; Fn("*",[Var x; pq])])
      end
  | (_,Fn("+",_)) -> poly_cadd vars pol1 pol2
  | (Fn("+",_),_) -> poly_cadd vars pol2 pol1
  | _ -> numeral2 (+/) pol1 pol2
(* [poly_cadd vars k pol] adds k into the constant part of  d + y * q.       *)
(* The match is deliberately partial: pol must have that shape.              *)
and poly_cadd vars pol1 pol2 =
  match pol2 with
  | Fn("+",[d; Fn("*",[Var y; q])]) ->
      Fn("+",[poly_add vars pol1 d; Fn("*",[Var y; q])]);;

(* Negate a polynomial: both parts of  c + x * p  are negated recursively,   *)
(* and anything else is negated as a numeral through numeral1 minus_num.     *)
let rec poly_neg pol =
  match pol with
  | Fn("+",[c; Fn("*",[Var x; p])]) ->
      Fn("+",[poly_neg c; Fn("*",[Var x; poly_neg p])])
  | n -> numeral1 minus_num n;;

(* Subtraction, expressed as addition of the negated second argument.        *)
let poly_sub vars p q =
  let negated = poly_neg q in
  poly_add vars p negated;;

(* Multiplication of polynomials in the nested form  c + x * p.              *)
(* When both arguments have that form, the one whose variable is earlier in  *)
(* vars is passed as the second argument of poly_cmul. A literal zero on     *)
(* either side gives zero. If only one argument has the form, the other is   *)
(* multiplied into it. Otherwise both are handed to numeral2 for numeral     *)
(* multiplication.                                                           *)
let rec poly_mul vars pol1 pol2 =
  match (pol1,pol2) with
  | (Fn("+",[_; Fn("*",[Var x; _])]),Fn("+",[_; Fn("*",[Var y; _])])) ->
      let (factor,target) =
        if earlier vars x y then (pol2,pol1) else (pol1,pol2) in
      poly_cmul vars factor target
  | (Fn("0",[]),_) | (_,Fn("0",[])) -> Fn("0",[])
  | (_,Fn("+",_)) -> poly_cmul vars pol1 pol2
  | (Fn("+",_),_) -> poly_cmul vars pol2 pol1
  | _ -> numeral2 ( */ ) pol1 pol2
(* [poly_cmul vars k pol] multiplies  d + y * q  by k: the products k * d    *)
(* and  0 + y * (k * q)  are combined with poly_add. The match is            *)
(* deliberately partial: pol must have that shape.                           *)
and poly_cmul vars pol1 pol2 =
  match pol2 with
  | Fn("+",[d; Fn("*",[Var y; q])]) ->
      let high = Fn("+",[Fn("0",[]); Fn("*",[Var y; poly_mul vars pol1 q])]) in
      let low = poly_mul vars pol1 d in
      poly_add vars low high;;

(* [poly_pow vars p n] multiplies the numeral 1 by p, n times over (funpow). *)
let poly_pow vars p n =
  let times_p = poly_mul vars p in
  funpow n times_p (Fn("1",[]));;

(* ------------------------------------------------------------------------- *)
(* Convert term into canonical polynomial representative.                    *)
(* ------------------------------------------------------------------------- *)

(* Convert a term to the nested polynomial form used by poly_add etc.        *)
(*   - a variable x becomes  0 + x * 1;                                      *)
(*   - a constant is kept when num_of_string accepts its name, and rejected  *)
(*     with Failure otherwise;                                               *)
(*   - p ^ n, with n a constant whose name int_of_string can read, uses      *)
(*     poly_pow;                                                             *)
(*   - binary *, +, - and unary - use the corresponding poly_ operation on   *)
(*     the converted arguments;                                              *)
(*   - any other function symbol is rejected with Failure.                   *)
let rec polynate vars tm =
  let conv = polynate vars in
  match tm with
  | Var x -> Fn("+",[Fn("0",[]); Fn("*",[Var x; Fn("1",[])])])
  | Fn(ns,[]) ->
      if can num_of_string ns then tm else failwith "Unexpected constant"
  | Fn("^",[p;Fn(n,[])]) -> poly_pow vars (conv p) (int_of_string n)
  | Fn("*",[s;t]) -> poly_mul vars (conv s) (conv t)
  | Fn("+",[s;t]) -> poly_add vars (conv s) (conv t)
  | Fn("-",[s;t]) -> poly_sub vars (conv s) (conv t)
  | Fn("-",[t]) -> poly_neg (conv t)
  | Fn(s,_) -> failwith ("Unexpected function symbol: "^s);;

(* ------------------------------------------------------------------------- *)
(* Do likewise for atom so the RHS is zero.                                  *)
(* ------------------------------------------------------------------------- *)

(* Rewrite a binary atom  a(s,t)  as  a(polynate(s - t), 0).                 *)
(* Fails with Failure on anything that is not a binary atom.                 *)
let polyatom vars fm =
  match fm with
  | Atom(R(a,[s;t])) ->
      let difference = polynate vars (Fn("-",[s;t])) in
      Atom(R(a,[difference; Fn("0",[])]))
  | _ -> failwith "polyatom: not an atom";;

(* ------------------------------------------------------------------------- *)
(* Sanity check.                                                             *)
(* ------------------------------------------------------------------------- *)

(*
let liouville = polyatom ["w"; "x"; "y"; "z"]
 <<6 * (w^2 + x^2 + y^2 + z^2)^2 =
   (((w + x)^4 + (w + y)^4 + (w + z)^4 +
     (x + y)^4 + (x + z)^4 + (y + z)^4) +
    ((w - x)^4 + (w - y)^4 + (w - z)^4 +
     (x - y)^4 + (x - z)^4 + (y - z)^4))>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Useful utility functions for polynomial terms.                            *)
(* ------------------------------------------------------------------------- *)

(* Degree of a polynomial in the first variable of vars: the number of       *)
(* nested  c + x * q  layers whose variable x is hd vars.                    *)
let rec degree vars pol =
  match pol with
  | Fn("+",[_; Fn("*",[Var x; q])]) when x = hd vars -> 1 + degree vars q
  | _ -> 0;;

(* List of coefficients of a polynomial with respect to hd vars, constant    *)
(* part first: each layer  c + x * q  with x = hd vars contributes c, and    *)
(* the innermost remaining term is the final element.                        *)
let rec coefficients vars pol =
  match pol with
  | Fn("+",[c; Fn("*",[Var x; q])]) when x = hd vars ->
      c :: coefficients vars q
  | _ -> [pol];;

(* Head coefficient: the last element of the coefficient list.               *)
let head vars p =
  let cs = coefficients vars p in
  last cs;;

(* A polynomial is constant with respect to hd vars unless it has the shape  *)
(* c + x * q  with x = hd vars.                                              *)
let is_constant vars p =
  match p with
  | Fn("+",[_; Fn("*",[Var x; _])]) -> not (x = hd vars)
  | _ -> true;;

(* Remove the head term of a polynomial in hd vars: the innermost part is    *)
(* replaced by zero, and a layer whose higher part has become zero collapses *)
(* to its constant part. A polynomial that is not of the form  c + x * q     *)
(* with x = hd vars yields zero.                                             *)
let rec behead vars pol =
  match pol with
  | Fn("+",[c; Fn("*",[Var x; q])]) when x = hd vars ->
      (match behead vars q with
       | Fn("0",[]) -> c
       | rest -> Fn("+",[c; Fn("*",[Var x; rest])]))
  | _ -> Fn("0",[]);;

(* ------------------------------------------------------------------------- *)
(* Get the constant multiple of the "maximal" monomial (implicit lex order)  *)
(* ------------------------------------------------------------------------- *)

(* Follow the higher part q of each layer  c + x * q  (for any variable x)   *)
(* down to a constant and return its numeric value via dest_numeral.         *)
(* The match is deliberately partial, as in the rest of this file.           *)
let rec headconst = function
  | Fn("+",[_; Fn("*",[Var _; q])]) -> headconst q
  | Fn(_,[]) as k -> dest_numeral k;;

(* ------------------------------------------------------------------------- *)
(* Make a polynomial monic and return negativity flag for head constant      *)
(* ------------------------------------------------------------------------- *)

(* [monic vars p] returns a pair (q, negative).                              *)
(* h is the head constant of p (headconst). If h is zero, p is returned      *)
(* unchanged with the flag false. Otherwise q is p multiplied by 1 / h and   *)
(* the flag tells whether h was negative.                                    *)
(* The zero test uses the Num equality =/ rather than polymorphic equality,  *)
(* matching the </ comparison on the same value below.                       *)
let monic vars p =
  let h = headconst p in
  if h =/ Int 0 then p,false else
  poly_mul vars (mk_numeral(Int 1 // h)) p,h </ Int 0;;

(* ------------------------------------------------------------------------- *)
(* Pseudo-division of s by p; head coefficient of p assumed nonzero.         *)
(* Returns (k,r) so that a^k s = p q + r for some q, deg(r) < deg(p).        *)
(* Optimized only for the trivial case of equal head coefficients; no GCDs.  *)
(* ------------------------------------------------------------------------- *)

(* [pdivide vars s p] pseudo-divides s by p with respect to hd vars and      *)
(* returns (k, r). With a the head coefficient of p and n its degree, the    *)
(* loop stops as soon as the running remainder r is zero or has degree below *)
(* n. Otherwise p is shifted up to the degree of r (shift1 applied m - n     *)
(* times) and then:                                                          *)
(*   - if the head coefficient b of r equals a, the shifted p is subtracted  *)
(*     and k is left alone;                                                  *)
(*   - else r becomes  a * r - b * shifted p  and k is incremented.          *)
let pdivide vars s p =
  let shift1 x q = Fn("+",[Fn("0",[]); Fn("*",[Var x; q])]) in
  let n = degree vars p in
  let a = head vars p in
  let rec loop k r =
    if r = Fn("0",[]) then (k,r) else
    let b = head vars r and m = degree vars r in
    if m < n then (k,r) else
    let shifted = funpow (m - n) (shift1 (hd vars)) p in
    if a = b then loop k (poly_sub vars r shifted)
    else
      loop (k+1)
        (poly_sub vars (poly_mul vars a r) (poly_mul vars b shifted)) in
  loop 0 s;;

(* ------------------------------------------------------------------------- *)
(* Datatype of signs.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Sign information that can be recorded about a polynomial.                 *)
type sign = Zero | Nonzero | Positive | Negative;;

(* [swap swf s] exchanges Positive and Negative when the flag swf is true.   *)
(* Zero and Nonzero, and every sign when swf is false, are returned as is.   *)
let swap swf s =
  match swf,s with
  | true,Positive -> Negative
  | true,Negative -> Positive
  | _ -> s;;

(* ------------------------------------------------------------------------- *)
(* Lookup and asserting of polynomial sign, modulo constant multiples.       *)
(* Note that we are building in a characteristic-zero assumption here.       *)
(* ------------------------------------------------------------------------- *)

(* Look up the recorded sign of p in the context sgns. p is first made monic *)
(* and the stored sign is swapped when monic reports a negative head         *)
(* constant. Any Failure raised on the way is reported as Failure "findsign".*)
let findsign vars sgns p =
  try
    let normalized,flipped = monic vars p in
    let recorded = assoc normalized sgns in
    swap flipped recorded
  with Failure _ -> failwith "findsign";;

(* [assertsign vars sgns (p,s)] returns the context sgns extended with the   *)
(* information that p has sign s.                                            *)
(*   - The literal zero polynomial may only be asserted Zero, in which case  *)
(*     the context is returned unchanged; any other sign is a Failure.       *)
(*   - Otherwise p is made monic and s is swapped when monic reports a       *)
(*     negative head constant. s0 is the sign already stored for the monic   *)
(*     polynomial, or the new sign itself when nothing is stored.            *)
(*   - The assertion is accepted when the new sign equals s0, or when s0 is  *)
(*     Nonzero and the new sign refines it to Positive or Negative. The old  *)
(*     entry is then removed and the new one put in front.                   *)
(*   - Any other combination raises Failure "assertsign".                    *)
(* The condition is written with || and &&, which group exactly like the     *)
(* deprecated keywords  or  and  &  used previously.                         *)
let assertsign vars sgns (p,s) =
  if p = Fn("0",[]) then
    if s = Zero then sgns else failwith "assertsign"
  else
    let p',swf = monic vars p in
    let s' = swap swf s in
    let s0 = try assoc p' sgns with Failure _ -> s' in
    if s' = s0 || s0 = Nonzero && (s' = Positive || s' = Negative)
    then (p',s')::(subtract sgns [p',s0]) else failwith "assertsign";;

(* ------------------------------------------------------------------------- *)
(* Deduce or case-split over zero status of polynomial.                      *)
(* ------------------------------------------------------------------------- *)

(* Continue with cont_z or cont_n depending on whether pol is zero.          *)
(* If the sign of pol is already recorded, the matching continuation is run  *)
(* on the unchanged context. If findsign fails with Failure "findsign", the  *)
(* result is a case split:                                                   *)
(*   (pol = 0 /\ cont_z ctx_zero) \/ (~(pol = 0) /\ cont_n ctx_nonzero)      *)
(* where each context is sgns extended by the corresponding assertion.       *)
let split_zero vars sgns pol cont_z cont_n =
  try
    let continuation =
      match findsign vars sgns pol with
      | Zero -> cont_z
      | _ -> cont_n in
    continuation sgns
  with Failure "findsign" ->
    let eq = Atom(R("=",[pol; Fn("0",[])])) in
    (* The nonzero branch is built before the zero branch. *)
    let nonzero_case =
      And(Not eq,cont_n (assertsign vars sgns (pol,Nonzero))) in
    let zero_case =
      And(eq,cont_z (assertsign vars sgns (pol,Zero))) in
    Or(zero_case,nonzero_case);;

(* ------------------------------------------------------------------------- *)
(* Whether a polynomial is nonzero in a context.                             *)
(* ------------------------------------------------------------------------- *)

(* Formula stating that pol is not the zero polynomial in hd vars.           *)
(* The coefficients are split into those whose sign findsign can determine   *)
(* and the rest. The answer is True if some determined coefficient is not    *)
(* Zero, False if no undetermined coefficient is left, and otherwise the     *)
(* disjunction of  ~(c = 0)  over the undetermined coefficients c.           *)
let poly_nonzero vars sgns pol =
  let cs = coefficients vars pol in
  let known,unknown = partition (can (findsign vars sgns)) cs in
  let surely_nonzero c = not(findsign vars sgns c = Zero) in
  if exists surely_nonzero known then True
  else match unknown with
    | [] -> False
    | _ ->
        let neq0 c = Not(Atom(R("=",[c; Fn("0",[])]))) in
        end_itlist (fun p q -> Or(p,q)) (map neq0 unknown);;

(* ------------------------------------------------------------------------- *)
(* Divisibility and hence variety inclusion.                                 *)
(* ------------------------------------------------------------------------- *)

(* Formula stating that p does not divide q: q is repeatedly replaced by its *)
(* pseudo-remainder modulo p until its degree drops below that of p, and the *)
(* final remainder is required to be nonzero.                                *)
let rec poly_not_divides vars sgns p q =
  if degree vars q >= degree vars p then
    let remainder = snd (pdivide vars q p) in
    poly_not_divides vars sgns p remainder
  else poly_nonzero vars sgns q;;

(* Formula stating that p does not divide q raised to the degree of p.       *)
let poly_variety vars sgns p q =
  let q_power = poly_pow vars q (degree vars p) in
  poly_not_divides vars sgns p q_power;;

(* ------------------------------------------------------------------------- *)
(* Main reduction for ?x. all eqs == 0 and all neqs =/= 0, in context.       *)
(* ------------------------------------------------------------------------- *)

(* [reduce vars (eqs,neqs) sgns] builds a formula for the conjunction of     *)
(* every polynomial in eqs being zero and every polynomial in neqs being     *)
(* nonzero, under the sign context sgns. The eliminated variable is hd vars  *)
(* (basic_complex_qelim calls this with x::vars).                            *)
let rec reduce vars (eqs,neqs) sgns =
  (* Step 1: pick an equation c that is constant with respect to hd vars.    *)
  (* Presumably find raises Failure when there is none, which sends control  *)
  (* to the outer handler below; confirm against the definition of find.     *)
  try let c = find (is_constant vars) eqs in
      (* Sign of c already recorded: Zero means the equation holds and is    *)
      (* dropped; any other recorded sign makes the whole conjunction False. *)
      try if findsign vars sgns c = Zero
          then reduce vars (subtract eqs [c],neqs) sgns else False
      (* No sign available: emit  c = 0  as a side condition and go on with  *)
      (* c asserted Zero in the context.                                     *)
      (* NOTE(review): this handler also intercepts any Failure that escapes *)
      (* from the recursive call in the branch above.                        *)
      with Failure _ ->
          And(Atom(R("=",[c;Fn("0",[])])),
              reduce vars (subtract eqs [c],neqs)
                          (assertsign vars sgns (c,Zero)))
  (* Step 2: reached on a Failure from step 1, including one raised inside   *)
  (* the inner handler. Dispatch on the number of equations.                 *)
  with Failure _ -> match (eqs,neqs) with
    (* No equations: every inequation must be a nonzero polynomial.          *)
    ([],neqs) -> list_conj (map (poly_nonzero vars sgns) neqs)
    (* One equation p: case-split on its head coefficient. If that is zero,  *)
    (* retry with p beheaded. If it is nonzero, the answer is True when      *)
    (* there are no inequations, and otherwise poly_variety applied to p and *)
    (* the pseudo-remainder of the product of the inequations modulo p.      *)
  | ([p],neqs) ->
        split_zero vars sgns (head vars p)
          (reduce vars ([behead vars p],neqs))
          (fun sgns ->
             if neqs = [] then True else
             poly_variety vars sgns p
              (snd(pdivide vars (end_itlist (poly_mul vars) neqs) p)))
    (* Several equations: take one, p, of minimal degree and case-split on   *)
    (* its head coefficient. If zero, p is replaced by its beheaded form; if *)
    (* nonzero, every other equation is replaced by its pseudo-remainder     *)
    (* modulo p (cfn).                                                       *)
  | (_,_) ->
        let n = end_itlist min (map (degree vars) eqs) in
        let p = find (fun p -> degree vars p = n) eqs in
        let oeqs = subtract eqs [p] in
        let cfn q = snd(pdivide vars q p) in
        split_zero vars sgns (head vars p)
          (reduce vars (behead vars p::oeqs,neqs))
          (reduce vars (p::(map cfn oeqs),neqs));;

(* ------------------------------------------------------------------------- *)
(* Basic complex quantifier elimination on actual existential formula.       *)
(* ------------------------------------------------------------------------- *)

(* Left-hand side s of an atom of the exact shape  s = 0.  The match is      *)
(* deliberately partial: any other formula is a match failure.               *)
let lhz fm =
  match fm with
    Atom(R("=",[s; Fn("0",[])])) -> s;;

(* Left-hand side s of a negated atom  ~(s = 0).  The match is deliberately  *)
(* partial: a formula that is not a negation is a match failure.             *)
let lhnz fm =
  match fm with
    Not body -> lhz body;;

(* Eliminate the quantifier from  exists x. p.  The conjuncts of p are split *)
(* into non-negated ones (equations) and negated ones (inequations), their   *)
(* left-hand sides are extracted with lhz and lhnz, and reduce is run with x *)
(* put in front of vars and an initial context recording that the numeral 1  *)
(* is Positive. Fails with Failure if fm is not an existential formula.      *)
let basic_complex_qelim vars fm =
  match fm with
  | Exists(x,p) ->
      let eqs,neqs = partition (non negative) (conjuncts p) in
      let neq_polys = map lhnz neqs in
      let eq_polys = map lhz eqs in
      let initial_signs = [Fn("1",[]),Positive] in
      reduce (x::vars) (eq_polys,neq_polys) initial_signs
  | _ -> failwith "basic_complex_qelim: not an existential formula";;

(* ------------------------------------------------------------------------- *)
(* Full quantifier elimination.                                              *)
(* ------------------------------------------------------------------------- *)

(* Top-level procedure: lift_qelim is instantiated with polyatom as the atom *)
(* normalizer, with evalc followed by cnnf (using the identity on literals)  *)
(* and then dnf as the formula normalizer, and with basic_complex_qelim as   *)
(* the core step; its result is passed through evalc and then simplify.      *)
let complex_qelim =
  let normalize = dnf ** cnnf (fun fm -> fm) ** evalc in
  let eliminate = lift_qelim polyatom normalize basic_complex_qelim in
  simplify ** evalc ** eliminate;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
complex_qelim
 <<forall a x. a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 1 = 0>>;;

complex_qelim
 <<forall a x. a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + c = 0>>;;

complex_qelim
 <<forall x y. x^2 = 2 /\ y^2 = 3 ==> (x * y)^2 = 6>>;;

complex_qelim
 <<forall a b c x y. (a * x^2 + b * x + c = 0) /\
                     (a * y^2 + b * y + c = 0) /\
                     ~(x = y)
                     ==> (a * x * y = c) /\ (a * (x + y) + b = 0)>>;;

complex_qelim
 <<forall x y.
    (forall a b c. (a * x^2 + b * x + c = 0) /\
                   (a * y^2 + b * y + c = 0)
                   ==> (a * x * y = c) /\ (a * (x + y) + b = 0))
    <=> ~(x = y)>>;;

(* ------------------------------------------------------------------------- *)
(* More serious test for pure normalization code.                            *)
(* ------------------------------------------------------------------------- *)

let polytest tm = time (polynate (fvt tm)) tm;;

let lagrange_4 = polytest
 <<|(((x1^2) + (x2^2) + (x3^2) + (x4^2)) *
     ((y1^2) + (y2^2) + (y3^2) + (y4^2))) -
    ((((((x1*y1) - (x2*y2)) - (x3*y3)) - (x4*y4))^2)  +
     (((((x1*y2) + (x2*y1)) + (x3*y4)) - (x4*y3))^2)  +
     (((((x1*y3) - (x2*y4)) + (x3*y1)) + (x4*y2))^2)  +
     (((((x1*y4) + (x2*y3)) - (x3*y2)) + (x4*y1))^2))|>>;;

let lagrange_8 = polytest
 <<|((p1^2 + q1^2 + r1^2 + s1^2 + t1^2 + u1^2 + v1^2 + w1^2) *
     (p2^2 + q2^2 + r2^2 + s2^2 + t2^2 + u2^2 + v2^2 + w2^2)) -
     ((p1 * p2 - q1 * q2 - r1 * r2 - s1 * s2 - t1 * t2 - u1 * u2 - v1 * v2 - w1* w2)^2 +
      (p1 * q2 + q1 * p2 + r1 * s2 - s1 * r2 + t1 * u2 - u1 * t2 - v1 * w2 + w1* v2)^2 +
      (p1 * r2 - q1 * s2 + r1 * p2 + s1 * q2 + t1 * v2 + u1 * w2 - v1 * t2 - w1* u2)^2 +
      (p1 * s2 + q1 * r2 - r1 * q2 + s1 * p2 + t1 * w2 - u1 * v2 + v1 * u2 - w1* t2)^2 +
      (p1 * t2 - q1 * u2 - r1 * v2 - s1 * w2 + t1 * p2 + u1 * q2 + v1 * r2 + w1* s2)^2 +
      (p1 * u2 + q1 * t2 - r1 * w2 + s1 * v2 - t1 * q2 + u1 * p2 - v1 * s2 + w1* r2)^2 +
      (p1 * v2 + q1 * w2 + r1 * t2 - s1 * u2 - t1 * r2 + u1 * s2 + v1 * p2 - w1* q2)^2 +
      (p1 * w2 - q1 * v2 + r1 * u2 + s1 * t2 - t1 * s2 - u1 * r2 + v1 * q2 + w1* p2)^2)|>>;;

let liouville = polytest
 <<|6 * (x1^2 + x2^2 + x3^2 + x4^2)^2 -
    (((x1 + x2)^4 + (x1 + x3)^4 + (x1 + x4)^4 +
      (x2 + x3)^4 + (x2 + x4)^4 + (x3 + x4)^4) +
     ((x1 - x2)^4 + (x1 - x3)^4 + (x1 - x4)^4 +
      (x2 - x3)^4 + (x2 - x4)^4 + (x3 - x4)^4))|>>;;

let fleck = polytest
 <<|60 * (x1^2 + x2^2 + x3^2 + x4^2)^3 -
    (((x1 + x2 + x3)^6 + (x1 + x2 - x3)^6 +
      (x1 - x2 + x3)^6 + (x1 - x2 - x3)^6 +
      (x1 + x2 + x4)^6 + (x1 + x2 - x4)^6 +
      (x1 - x2 + x4)^6 + (x1 - x2 - x4)^6 +
      (x1 + x3 + x4)^6 + (x1 + x3 - x4)^6 +
      (x1 - x3 + x4)^6 + (x1 - x3 - x4)^6 +
      (x2 + x3 + x4)^6 + (x2 + x3 - x4)^6 +
      (x2 - x3 + x4)^6 + (x2 - x3 - x4)^6) +
     2 * ((x1 + x2)^6 + (x1 - x2)^6 +
          (x1 + x3)^6 + (x1 - x3)^6 +
          (x1 + x4)^6 + (x1 - x4)^6 +
          (x2 + x3)^6 + (x2 - x3)^6 +
          (x2 + x4)^6 + (x2 - x4)^6 +
          (x3 + x4)^6 + (x3 - x4)^6) +
     36 * (x1^6 + x2^6 + x3^6 + x4^6))|>>;;

let hurwitz = polytest
 <<|5040 * (x1^2 + x2^2 + x3^2 + x4^2)^4 -
    (6 * ((x1 + x2 + x3 + x4)^8 +
          (x1 + x2 + x3 - x4)^8 +
          (x1 + x2 - x3 + x4)^8 +
          (x1 + x2 - x3 - x4)^8 +
          (x1 - x2 + x3 + x4)^8 +
          (x1 - x2 + x3 - x4)^8 +
          (x1 - x2 - x3 + x4)^8 +
          (x1 - x2 - x3 - x4)^8) +
     ((2 * x1 + x2 + x3)^8 +
      (2 * x1 + x2 - x3)^8 +
      (2 * x1 - x2 + x3)^8 +
      (2 * x1 - x2 - x3)^8 +
      (2 * x1 + x2 + x4)^8 +
      (2 * x1 + x2 - x4)^8 +
      (2 * x1 - x2 + x4)^8 +
      (2 * x1 - x2 - x4)^8 +
      (2 * x1 + x3 + x4)^8 +
      (2 * x1 + x3 - x4)^8 +
      (2 * x1 - x3 + x4)^8 +
      (2 * x1 - x3 - x4)^8 +
      (2 * x2 + x3 + x4)^8 +
      (2 * x2 + x3 - x4)^8 +
      (2 * x2 - x3 + x4)^8 +
      (2 * x2 - x3 - x4)^8 +
      (x1 + 2 * x2 + x3)^8 +
      (x1 + 2 * x2 - x3)^8 +
      (x1 - 2 * x2 + x3)^8 +
      (x1 - 2 * x2 - x3)^8 +
      (x1 + 2 * x2 + x4)^8 +
      (x1 + 2 * x2 - x4)^8 +
      (x1 - 2 * x2 + x4)^8 +
      (x1 - 2 * x2 - x4)^8 +
      (x1 + 2 * x3 + x4)^8 +
      (x1 + 2 * x3 - x4)^8 +
      (x1 - 2 * x3 + x4)^8 +
      (x1 - 2 * x3 - x4)^8 +
      (x2 + 2 * x3 + x4)^8 +
      (x2 + 2 * x3 - x4)^8 +
      (x2 - 2 * x3 + x4)^8 +
      (x2 - 2 * x3 - x4)^8 +
      (x1 + x2 + 2 * x3)^8 +
      (x1 + x2 - 2 * x3)^8 +
      (x1 - x2 + 2 * x3)^8 +
      (x1 - x2 - 2 * x3)^8 +
      (x1 + x2 + 2 * x4)^8 +
      (x1 + x2 - 2 * x4)^8 +
      (x1 - x2 + 2 * x4)^8 +
      (x1 - x2 - 2 * x4)^8 +
      (x1 + x3 + 2 * x4)^8 +
      (x1 + x3 - 2 * x4)^8 +
      (x1 - x3 + 2 * x4)^8 +
      (x1 - x3 - 2 * x4)^8 +
      (x2 + x3 + 2 * x4)^8 +
      (x2 + x3 - 2 * x4)^8 +
      (x2 - x3 + 2 * x4)^8 +
      (x2 - x3 - 2 * x4)^8) +
     60 * ((x1 + x2)^8 + (x1 - x2)^8 +
           (x1 + x3)^8 + (x1 - x3)^8 +
           (x1 + x4)^8 + (x1 - x4)^8 +
           (x2 + x3)^8 + (x2 - x3)^8 +
           (x2 + x4)^8 + (x2 - x4)^8 +
           (x3 + x4)^8 + (x3 - x4)^8) +
     6 * ((2 * x1)^8 + (2 * x2)^8 + (2 * x3)^8 + (2 * x4)^8))|>>;;

let schur = polytest
 <<|22680 * (x1^2 + x2^2 + x3^2 + x4^2)^5 -
    (9 * ((2 * x1)^10 +
          (2 * x2)^10 +
          (2 * x3)^10 +
          (2 * x4)^10) +
     180 * ((x1 + x2)^10 + (x1 - x2)^10 +
            (x1 + x3)^10 + (x1 - x3)^10 +
            (x1 + x4)^10 + (x1 - x4)^10 +
            (x2 + x3)^10 + (x2 - x3)^10 +
            (x2 + x4)^10 + (x2 - x4)^10 +
            (x3 + x4)^10 + (x3 - x4)^10) +
     ((2 * x1 + x2 + x3)^10 +
      (2 * x1 + x2 - x3)^10 +
      (2 * x1 - x2 + x3)^10 +
      (2 * x1 - x2 - x3)^10 +
      (2 * x1 + x2 + x4)^10 +
      (2 * x1 + x2 - x4)^10 +
      (2 * x1 - x2 + x4)^10 +
      (2 * x1 - x2 - x4)^10 +
      (2 * x1 + x3 + x4)^10 +
      (2 * x1 + x3 - x4)^10 +
      (2 * x1 - x3 + x4)^10 +
      (2 * x1 - x3 - x4)^10 +
      (2 * x2 + x3 + x4)^10 +
      (2 * x2 + x3 - x4)^10 +
      (2 * x2 - x3 + x4)^10 +
      (2 * x2 - x3 - x4)^10 +
      (x1 + 2 * x2 + x3)^10 +
      (x1 + 2 * x2 - x3)^10 +
      (x1 - 2 * x2 + x3)^10 +
      (x1 - 2 * x2 - x3)^10 +
      (x1 + 2 * x2 + x4)^10 +
      (x1 + 2 * x2 - x4)^10 +
      (x1 - 2 * x2 + x4)^10 +
      (x1 - 2 * x2 - x4)^10 +
      (x1 + 2 * x3 + x4)^10 +
      (x1 + 2 * x3 - x4)^10 +
      (x1 - 2 * x3 + x4)^10 +
      (x1 - 2 * x3 - x4)^10 +
      (x2 + 2 * x3 + x4)^10 +
      (x2 + 2 * x3 - x4)^10 +
      (x2 - 2 * x3 + x4)^10 +
      (x2 - 2 * x3 - x4)^10 +
      (x1 + x2 + 2 * x3)^10 +
      (x1 + x2 - 2 * x3)^10 +
      (x1 - x2 + 2 * x3)^10 +
      (x1 - x2 - 2 * x3)^10 +
      (x1 + x2 + 2 * x4)^10 +
      (x1 + x2 - 2 * x4)^10 +
      (x1 - x2 + 2 * x4)^10 +
      (x1 - x2 - 2 * x4)^10 +
      (x1 + x3 + 2 * x4)^10 +
      (x1 + x3 - 2 * x4)^10 +
      (x1 - x3 + 2 * x4)^10 +
      (x1 - x3 - 2 * x4)^10 +
      (x2 + x3 + 2 * x4)^10 +
      (x2 + x3 - 2 * x4)^10 +
      (x2 - x3 + 2 * x4)^10 +
      (x2 - x3 - 2 * x4)^10) +
     9 * ((x1 + x2 + x3 + x4)^10 +
          (x1 + x2 + x3 - x4)^10 +
          (x1 + x2 - x3 + x4)^10 +
          (x1 + x2 - x3 - x4)^10 +
          (x1 - x2 + x3 + x4)^10 +
          (x1 - x2 + x3 - x4)^10 +
          (x1 - x2 - x3 + x4)^10 +
          (x1 - x2 - x3 - x4)^10))|>>;;

(* ------------------------------------------------------------------------- *)
(* More non-trivial complex quantifier elimination.                          *)
(* ------------------------------------------------------------------------- *)

let complex_qelim_all = time complex_qelim ** generalize;;

time complex_qelim <<exists x. x + 2 = 3>>;;

time complex_qelim <<exists x. x^2 + a = 3>>;;

time complex_qelim <<exists x. x^2 + x + 1 = 0>>;;

time complex_qelim <<exists x. x^2 + x + 1 = 0 /\ x^3 + x^2 + 1 = 0>>;;

time complex_qelim <<exists x. x^2 + 1 = 0 /\ x^4 + x^3 + x^2 + x = 0>>;;

time complex_qelim 
  <<forall a x. a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 1 = 0>>;;

time complex_qelim 
  <<forall a x. a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 2 = 0>>;;

time complex_qelim 
  <<exists a x. a^2 = 2 /\ x^2 + a*x + 1 = 0 /\ ~(x^4 + 2 = 0)>>;;

time complex_qelim 
  <<exists x. a^2 = 2 /\ x^2 + a*x + 1 = 0 /\ ~(x^4 + 2 = 0)>>;;

time complex_qelim <<forall x. x^2 + a*x + 1 = 0 ==> x^4 + 2 = 0>>;;

time complex_qelim <<forall a. a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 2 = 0>>;;

time complex_qelim <<exists a b c x y.
        a * x^2 + b * x + c = 0 /\
        a * y^2 + b * y + c = 0 /\
        ~(x = y) /\
        ~(a * x * y = c)>>;;

time complex_qelim
 <<forall y_1 y_2 y_3 y_4.
     (y_1 = 2 * y_3) /\
     (y_2 = 2 * y_4) /\
     (y_1 * y_3 = y_2 * y_4)
     ==> (y_1^2 = y_2^2)>>;;

time complex_qelim
 <<forall x y. x^2 = 2 /\ y^2 = 3
         ==> (x * y)^2 = 6>>;;

time complex_qelim
 <<forall x a. (a^2 = 2) /\ (x^2 + a * x + 1 = 0)
         ==> (x^4 + 1 = 0)>>;;

time complex_qelim
 <<forall a x. (a^2 = 2) /\ (x^2 + a * x + 1 = 0)
         ==> (x^4 + 1 = 0)>>;;

time complex_qelim
 <<~(exists a x y. (a^2 = 2) /\
             (x^2 + a * x + 1 = 0) /\
             (y * (x^4 + 1) + 1 = 0))>>;;

time complex_qelim <<forall x. exists y. x^2 = y^3>>;;

time complex_qelim
 <<forall x y z a b. (a + b) * (x - y + z) - (a - b) * (x + y + z) =
               2 * (b * x + b * z - a * y)>>;;

time complex_qelim
 <<forall a b. ~(a = b) ==> exists x y. (y * x^2 = a) /\ (y * x^2 + x = b)>>;;

time complex_qelim
 <<forall a b c x y. (a * x^2 + b * x + c = 0) /\
               (a * y^2 + b * y + c = 0) /\
               ~(x = y)
               ==> (a * x * y = c) /\ (a * (x + y) + b = 0)>>;;

time complex_qelim
 <<~(forall a b c x y. (a * x^2 + b * x + c = 0) /\
                 (a * y^2 + b * y + c = 0)
                 ==> (a * x * y = c) /\ (a * (x + y) + b = 0))>>;;

time complex_qelim
 <<forall y_1 y_2 y_3 y_4.
     (y_1 = 2 * y_3) /\
     (y_2 = 2 * y_4) /\
     (y_1 * y_3 = y_2 * y_4)
     ==> (y_1^2 = y_2^2)>>;;

time complex_qelim
 <<forall a1 b1 c1 a2 b2 c2.
        ~(a1 * b2 = a2 * b1)
        ==> exists x y. (a1 * x + b1 * y = c1) /\ (a2 * x + b2 * y = c2)>>;;

time complex_qelim
 <<~(forall x1 y1 x2 y2 x3 y3.
      exists x0 y0. (x1 - x0)^2 + (y1 - y0)^2 = (x2 - x0)^2 + (y2 - y0)^2 /\
                    (x2 - x0)^2 + (y2 - y0)^2 = (x3 - x0)^2 + (y3 - y0)^2)>>;;

time complex_qelim
 <<forall a b c.
      (exists x y. (a * x^2 + b * x + c = 0) /\
             (a * y^2 + b * y + c = 0) /\
             ~(x = y)) <=>
      (a = 0) /\ (b = 0) /\ (c = 0) \/
      ~(a = 0) /\ ~(b^2 = 4 * a * c)>>;;

time complex_qelim
 <<~(forall x1 y1 x2 y2 x3 y3 x0 y0 x0' y0'.
        (x1 - x0)^2 + (y1 - y0)^2 =
        (x2 - x0)^2 + (y2 - y0)^2 /\
        (x2 - x0)^2 + (y2 - y0)^2 =
        (x3 - x0)^2 + (y3 - y0)^2 /\
        (x1 - x0')^2 + (y1 - y0')^2 =
        (x2 - x0')^2 + (y2 - y0')^2 /\
        (x2 - x0')^2 + (y2 - y0')^2 =
        (x3 - x0')^2 + (y3 - y0')^2
        ==> x0 = x0' /\ y0 = y0')>>;;

time complex_qelim
 <<forall a b c.
        a * x^2 + b * x + c = 0 /\
        a * y^2 + b * y + c = 0 /\
        ~(x = y)
        ==> a * (x + y) + b = 0>>;;

time complex_qelim
 <<forall a b c.
        (a * x^2 + b * x + c = 0) /\
        (2 * a * y^2 + 2 * b * y + 2 * c = 0) /\
        ~(x = y)
        ==> (a * (x + y) + b = 0)>>;;

complex_qelim_all
 <<~(y_1 = 2 * y_3 /\
    y_2 = 2 * y_4 /\
    y_1 * y_3 = y_2 * y_4 /\
    (y_1^2 - y_2^2) * z = 1)>>;;

time complex_qelim <<forall y_1 y_2 y_3 y_4.
       (y_1 = 2 * y_3) /\
       (y_2 = 2 * y_4) /\
       (y_1 * y_3 = y_2 * y_4)
       ==> (y_1^2 = y_2^2)>>;;

complex_qelim_all
 <<(x1 = u3) /\
  (x1 * (u2 - u1) = x2 * u3) /\
  (x4 * (x2 - u1) = x1 * (x3 - u1)) /\
  (x3 * u3 = x4 * u2) /\
  ~(u1 = 0) /\
  ~(u3 = 0)
  ==> (x3^2 + x4^2 = (u2 - x3)^2 + (u3 - x4)^2)>>;;

complex_qelim_all
 <<(u1 * x1 - u1 * u3 = 0) /\
  (u3 * x2 - (u2 - u1) * x1 = 0) /\
  (x1 * x4 - (x2 - u1) * x3 - u1 * x1 = 0) /\
  (u3 * x4 - u2 * x3 = 0) /\
  ~(u1 = 0) /\
  ~(u3 = 0)
  ==> (2 * u2 * x4 + 2 * u3 * x3 - u3^2 - u2^2 = 0)>>;;

complex_qelim_all
 <<(y1 * y3 + x1 * x3 = 0) /\
  (y3 * (y2 - y3) + (x2 - x3) * x3 = 0) /\
  ~(x3 = 0) /\
  ~(y3 = 0)
  ==> (y1 * (x2 - x3) = x1 * (y2 - y3))>>;;

time complex_qelim 
  <<forall y.
         a * x^2 + b * x + c = 0 /\
         a * y^2 + b * y + c = 0 /\
         ~(x = y)
         ==> a * x * y = c /\ a * (x + y) + b = 0>>;;

complex_qelim_all
 <<a * x^2 + b * x + c = 0 /\
    a * y^2 + b * y + c = 0 /\
    ~(x = y)
    ==> a * x * y = c /\ a * (x + y) + b = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Checking resultants from Maple.                                           *)
(* ------------------------------------------------------------------------- *)

time complex_qelim
<<forall a b c.
   (exists x. a * x^2 + b * x + c = 0 /\ 2 * a * x + b = 0) \/ (a = 0) <=>
   (4*a^2*c-b^2*a = 0)>>;;

time complex_qelim
<<forall a b c d e.
  (exists x. a * x^2 + b * x + c = 0 /\ d * x + e = 0) \/
   a = 0 /\ d = 0 <=> d^2*c-e*d*b+a*e^2 = 0>>;;

time complex_qelim
<<forall a b c d e f.
   (exists x. a * x^2 + b * x + c = 0 /\ d * x^2 + e * x + f = 0) \/
   (a = 0) /\ (d = 0) <=>
   d^2*c^2-2*d*c*a*f+a^2*f^2-e*d*b*c-e*b*a*f+a*e^2*c+f*d*b^2 = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Some trigonometric addition formulas (checking stuff from Maple).         *)
(* ------------------------------------------------------------------------- *)

time complex_qelim
  <<forall x y. x^2 + y^2 = 1 ==> (2 * y^2 - 1)^2 + (2 * x * y)^2 = 1>>;;

(* ------------------------------------------------------------------------- *)
(* The examples from my thesis.                                              *)
(* ------------------------------------------------------------------------- *)

time complex_qelim <<forall s c. s^2 + c^2 = 1
      ==> 2 * s - (2 * s * c * c - s^3) = 3 * s^3>>;;

time complex_qelim <<forall u v.
  -((((9 * u^8) * v) * v - (u * u^9)) * 128) -
     (((7 * u^6) * v) * v - (u * u^7)) * 144 -
     (((5 * u^4) * v) * v - (u * u^5)) * 168 -
     (((3 * u^2) * v) * v - (u * u^3)) * 210 -
     (v * v - (u * u)) * 315 + 315 - 1280 * u^10 =
   (-(1152) * u^8 - 1008 * u^6 - 840 * u^4 - 630 * u^2 - 315) *
   (u^2 + v^2 - 1)>>;;

time complex_qelim <<forall u v.
        u^2 + v^2 = 1
        ==> (((9 * u^8) * v) * v - (u * u^9)) * 128 +
            (((7 * u^6) * v) * v - (u * u^7)) * 144 +
            (((5 * u^4) * v) * v - (u * u^5)) * 168 +
            (((3 * u^2) * v) * v - (u * u^3)) * 210 +
            (v * v - (u * u)) * 315 + 1280 * u^10 = 315>>;;

(* ------------------------------------------------------------------------- *)
(* Deliberately silly examples from Poizat's model theory book (6.6).        *)
(* ------------------------------------------------------------------------- *)

time complex_qelim <<exists z. x * z^87 + y * z^44 + 1 = 0>>;;

time complex_qelim <<forall u. exists v. x * (u + v^2)^2 + y * (u + v^2) + z = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Actually prove simple equivalences.                                       *)
(* ------------------------------------------------------------------------- *)

time complex_qelim <<forall x y. (exists z. x * z^87 + y * z^44 + 1 = 0)
                  <=> ~(x = 0) \/ ~(y = 0)>>;;

time complex_qelim <<forall x y z. (forall u. exists v.
                         x * (u + v^2)^2 + y * (u + v^2) + z = 0)
                    <=> ~(x = 0) \/ ~(y = 0) \/ z = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Invertibility of 2x2 matrix in terms of nonzero determinant.              *)
(* ------------------------------------------------------------------------- *)

time complex_qelim <<exists w x y z. (a * w + b * y = 1) /\
                      (a * x + b * z = 0) /\
                      (c * w + d * y = 0) /\
                      (c * x + d * z = 1)>>;;

time complex_qelim <<forall a b c d.
        (exists w x y z. (a * w + b * y = 1) /\
                         (a * x + b * z = 0) /\
                         (c * w + d * y = 0) /\
                         (c * x + d * z = 1))
        <=> ~(a * d = b * c)>>;;

*)
(* ========================================================================= *)
(* Grobner basis algorithm.                                                  *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Monomial ordering.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Strict comparison of two monomials given as exponent lists: m1 is below   *)
(* m2 when its total degree (sum of exponents) is smaller, or when the total *)
(* degrees agree and lexord(>) holds on the reversed exponent lists.         *)
(* Uses the standard || and && instead of the deprecated or and &.           *)
let morder_lt m1 m2 =
  let n1 = itlist (+) m1 0 and n2 = itlist (+) m2 0 in
  n1 < n2 || (n1 = n2 && lexord(>) (rev m1) (rev m2));;

(* ------------------------------------------------------------------------- *)
(* Arithmetic on canonical polynomials.                                      *)
(* ------------------------------------------------------------------------- *)

(* Add two polynomials held as lists of (coefficient,monomial) pairs, merging *)
(* them by comparing leading monomials with morder_lt. Terms with equal       *)
(* monomials are combined and dropped when the coefficients cancel.           *)
let rec grob_add l1 l2 =
  match l1 with
    [] -> l2
  | (c1,m1)::o1 ->
      (match l2 with
         [] -> l1
       | (c2,m2)::o2 ->
           if m1 = m2 then
             let tail = grob_add o1 o2 in
             let coeff = c1 +/ c2 in
             if coeff =/ Int 0 then tail else (coeff,m1)::tail
           else if morder_lt m2 m1 then (c1,m1)::(grob_add o1 l2)
           else (c2,m2)::(grob_add l1 o2));;

(* Multiply two monomial terms: coefficients are multiplied and the exponent *)
(* lists are added pointwise.                                                *)
let grob_mmul (c1,m1) (c2,m2) =
  let exps = map2 (+) m1 m2 in
  (c1 */ c2,exps);;

(* Multiply every term of polynomial pol by the single term cm.              *)
let grob_cmul cm pol = map (fun term -> grob_mmul cm term) pol;;

(* Negate a polynomial by negating the coefficient of each term.             *)
let grob_neg pol =
  map (fun (coeff,mono) -> (minus_num coeff,mono)) pol;;

(* Subtract polynomial l2 from l1, implemented as addition of the negation.  *)
let grob_sub l1 l2 =
  let negated = grob_neg l2 in
  grob_add l1 negated;;

(* Multiply two polynomials: each term of l1 is multiplied through l2 and    *)
(* the partial products are summed with grob_add.                            *)
let rec grob_mul l1 l2 =
  match l1 with
    [] -> []
  | term::others ->
      let tail_product = grob_mul others l2 in
      grob_add (grob_cmul term l2) tail_product;;

(* ------------------------------------------------------------------------- *)
(* Monomial division operation.                                              *)
(* ------------------------------------------------------------------------- *)

(* Divide one monomial term by another: the coefficients are divided and the *)
(* exponents subtracted pointwise. Fails with mdiv as the message when some  *)
(* exponent of the divisor exceeds that of the dividend.                     *)
let mdiv (c1,m1) (c2,m2) =
  let exp_diff e1 e2 =
    if e1 >= e2 then e1 - e2 else failwith "mdiv" in
  let exps = map2 exp_diff m1 m2 in
  (c1 // c2,exps);;

(* ------------------------------------------------------------------------- *)
(* Reduce monomial cm by polynomial pol, returning replacement for cm.       *)
(* ------------------------------------------------------------------------- *)

(* Rewrite the term cm using polynomial pol: divide cm by the head term of   *)
(* pol and multiply the remaining terms of pol by the negated quotient.      *)
(* Fails on an empty pol, or (through mdiv) when the head does not divide.   *)
let reduce1 cm pol =
  match pol with
    lead::tail ->
      let (quot_c,quot_m) = mdiv cm lead in
      grob_cmul (minus_num quot_c,quot_m) tail
  | [] -> failwith "reduce1";;

(* ------------------------------------------------------------------------- *)
(* Try this for all polynomials in a basis.                                  *)
(* ------------------------------------------------------------------------- *)

(* Apply reduce1 to cm with the polynomials of basis, via tryfind.           *)
let reduceb cm basis =
  tryfind (fun pol -> reduce1 cm pol) basis;;

(* ------------------------------------------------------------------------- *)
(* Reduction of a polynomial (always picking largest monomial possible).     *)
(* ------------------------------------------------------------------------- *)

(* Reduce polynomial pol with respect to basis. The leading term is replaced *)
(* using reduceb when possible and the result reduced again; when that fails *)
(* the leading term is kept and reduction continues on the remaining terms.  *)
let rec reduce basis pol =
  match pol with
    [] -> []
  | lead::others ->
      try
        let replacement = reduceb lead basis in
        reduce basis (grob_add replacement others)
      with Failure _ -> lead::(reduce basis others);;

(* ------------------------------------------------------------------------- *)
(* Lowest common multiple of two monomials.                                  *)
(* ------------------------------------------------------------------------- *)

(* Least common multiple of two monomial terms: coefficient one and the      *)
(* pointwise maximum of the exponents. The input coefficients are ignored.   *)
let mlcm (_,m1) (_,m2) =
  let exps = map2 max m1 m2 in
  (Int 1,exps);;

(* ------------------------------------------------------------------------- *)
(* Compute S-polynomial of two polynomials (zero for the orthogonal case).   *)
(* ------------------------------------------------------------------------- *)

(* S-polynomial of pol1 and pol2. The result is empty when either input is   *)
(* empty, or when the lcm of the leading monomials equals their product.     *)
(* Otherwise each tail is scaled by lcm divided by its own leading term and  *)
(* the two scaled tails are subtracted.                                      *)
let spoly pol1 pol2 =
  match (pol1,pol2) with
    (lead1::rest1,lead2::rest2) ->
      let lcm = mlcm lead1 lead2 in
      let (_,prod_exps) = grob_mmul lead1 lead2 in
      let (_,lcm_exps) = lcm in
      if lcm_exps = prod_exps then []
      else
        let right = grob_cmul (mdiv lcm lead2) rest2 in
        let left = grob_cmul (mdiv lcm lead1) rest1 in
        grob_sub left right
  | _ -> [];;

(* ------------------------------------------------------------------------- *)
(* Grobner basis algorithm.                                                  *)
(* ------------------------------------------------------------------------- *)

(* Main completion loop. Each call prints the current number of basis        *)
(* elements and pending pairs, then handles the first pair: its reduced      *)
(* S-polynomial is discarded when empty, returned alone as the whole result  *)
(* when every exponent in it is zero, and otherwise added to the basis with  *)
(* new pairs against every existing basis element appended to the queue.     *)
let rec grobner basis pairs =
  let nbasis = string_of_int (length basis)
  and npairs = string_of_int (length pairs) in
  print_string (nbasis ^ " basis elements and " ^ npairs ^ " pairs");
  print_newline ();
  match pairs with
    [] -> basis
  | (p1,p2)::opairs ->
      let sp = reduce basis (spoly p1 p2) in
      let all_exponents_zero (_,exps) = forall (fun e -> 0 = e) exps in
      if sp = [] then grobner basis opairs
      else if forall all_exponents_zero sp then [sp]
      else
        let fresh_pairs = map (fun p -> (p,sp)) basis in
        grobner (sp::basis) (opairs @ fresh_pairs);;

(* ------------------------------------------------------------------------- *)
(* Overall function.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Entry point: run grobner on basis with the pairs from distinctpairs.      *)
let groebner basis =
  let pairs = distinctpairs basis in
  grobner basis pairs;;

(* ------------------------------------------------------------------------- *)
(* Convert formula into canonical form.                                      *)
(* ------------------------------------------------------------------------- *)

(* Polynomial for the variable x relative to the variable list vars: one     *)
(* term with coefficient one whose exponent list has 1 at each position of   *)
(* vars equal to x and 0 elsewhere.                                          *)
(* The definition is now terminated with ;; like the others in this file.    *)
let grob_var vars x =
  [Int 1,map (fun y -> if y = x then 1 else 0) vars];;

(* Polynomial for the constant n: empty when n is zero, otherwise one term   *)
(* with an all-zero exponent list of the same length as vars.                *)
let grob_const vars n =
  if not (n =/ Int 0) then [(n,map (fun _ -> 0) vars)] else [];;

(* Translate a term into a polynomial over vars. Variables, unary and binary *)
(* minus, addition, multiplication and powers with a numeral exponent are    *)
(* handled structurally; any other term is passed to dest_numeral and        *)
(* treated as a constant.                                                    *)
let rec grobterm vars tm =
  let conv = grobterm vars in
  (* Convert the right operand first, then the left, then combine. *)
  let binary op s t =
    let rhs = conv t in
    let lhs = conv s in
    op lhs rhs in
  match tm with
    Var x -> grob_var vars x
  | Fn("-",[t]) -> grob_neg (conv t)
  | Fn("+",[s;t]) -> binary grob_add s t
  | Fn("-",[s;t]) -> binary grob_sub s t
  | Fn("*",[s;t]) -> binary grob_mul s t
  | Fn("^",[t;n]) ->
      let one = grob_const vars (Int 1) in
      let base = conv t in
      funpow (int_of_num (dest_numeral n)) (grob_mul base) one
  | _ -> grob_const vars (dest_numeral tm);;

(* Translate an equation s = t into the polynomial for s - t. Any formula    *)
(* that is not an equality atom is rejected with a failure.                  *)
let grobatom vars fm =
  match fm with
    Atom(R("=",[lhs;rhs])) ->
      let difference = Fn("-",[lhs;rhs]) in
      grobterm vars difference
  | _ -> failwith "grobatom: not an equation";;

(* ------------------------------------------------------------------------- *)
(* Use the Rabinowitsch trick to eliminate inequations.                      *)
(* That is, replace p =/= 0 by exists w. p * w = 1.                          *)
(* ------------------------------------------------------------------------- *)

(* Build the polynomial 1 - v * p, where v is a variable name and p is an    *)
(* already translated polynomial.                                            *)
let rabinowitsch vars v p =
  let product = grob_mul (grob_var vars v) p in
  let one = grob_const vars (Int 1) in
  grob_sub one product;;

(* ------------------------------------------------------------------------- *)
(* Universal complex number decision procedure based on Grobner bases.       *)
(* ------------------------------------------------------------------------- *)

(* Decide whether a list of literals is inconsistent. The literals are split *)
(* by positive into equations and negated equations; one fresh variable is   *)
(* generated per negated equation and used by rabinowitsch to turn it into a *)
(* polynomial. The result is true when the constant 1 reduces to the empty   *)
(* polynomial modulo the basis computed by groebner.                         *)
let grobner_trivial fms =
  let eqs,neqs = partition positive fms in
  let vars0 = itlist (fun fm acc -> union (fv fm) acc) fms [] in
  let fresh_name n = variant ("_"^string_of_int n) vars0 in
  let rvs = map fresh_name (1--length neqs) in
  let vars = vars0 @ rvs in
  let poleqs = map (grobatom vars) eqs
  and polneqs = map (fun fm -> grobatom vars (negate fm)) neqs in
  let pols = poleqs @ map2 (rabinowitsch vars) rvs polneqs in
  let one = grob_const vars (Int 1) in
  let basis = groebner pols in
  reduce basis one = [];;

(* Decision procedure: simplify fm, convert with nnf and prenex, apply       *)
(* specialize, negate, expand with simpdnf, and check that grobner_trivial   *)
(* holds for every resulting list of literals.                               *)
let grobner_decide fm =
  let body = specialize (prenex (nnf (simplify fm))) in
  let refutation_cases = simpdnf (nnf (Not body)) in
  forall grobner_trivial refutation_cases;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
grobner_decide
  <<a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 1 = 0>>;;

grobner_decide
  <<a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 2 = 0>>;;

grobner_decide
  <<(a * x^2 + b * x + c = 0) /\
   (a * y^2 + b * y + c = 0) /\
   ~(x = y)
   ==> (a * x * y = c) /\ (a * (x + y) + b = 0)>>;;

(* ------------------------------------------------------------------------- *)
(* Compare with earlier procedure.                                           *)
(* ------------------------------------------------------------------------- *)

let fm =
  <<(a * x^2 + b * x + c = 0) /\
    (a * y^2 + b * y + c = 0) /\
    ~(x = y)
    ==> (a * x * y = c) /\ (a * (x + y) + b = 0)>> in
time complex_qelim (generalize fm),time grobner_decide fm;;

(* ------------------------------------------------------------------------- *)
(* More tests.                                                               *)
(* ------------------------------------------------------------------------- *)

time grobner_decide  <<a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 1 = 0>>;;

time grobner_decide  <<a^2 = 2 /\ x^2 + a*x + 1 = 0 ==> x^4 + 2 = 0>>;;

time grobner_decide <<(a * x^2 + b * x + c = 0) /\
      (a * y^2 + b * y + c = 0) /\
      ~(x = y)
      ==> (a * x * y = c) /\ (a * (x + y) + b = 0)>>;;

time grobner_decide
 <<(y_1 = 2 * y_3) /\
  (y_2 = 2 * y_4) /\
  (y_1 * y_3 = y_2 * y_4)
  ==> (y_1^2 = y_2^2)>>;;

time grobner_decide
 <<(x1 = u3) /\
  (x1 * (u2 - u1) = x2 * u3) /\
  (x4 * (x2 - u1) = x1 * (x3 - u1)) /\
  (x3 * u3 = x4 * u2) /\
  ~(u1 = 0) /\
  ~(u3 = 0)
  ==> (x3^2 + x4^2 = (u2 - x3)^2 + (u3 - x4)^2)>>;;

time grobner_decide
 <<(u1 * x1 - u1 * u3 = 0) /\
  (u3 * x2 - (u2 - u1) * x1 = 0) /\
  (x1 * x4 - (x2 - u1) * x3 - u1 * x1 = 0) /\
  (u3 * x4 - u2 * x3 = 0) /\
  ~(u1 = 0) /\
  ~(u3 = 0)
  ==> (2 * u2 * x4 + 2 * u3 * x3 - u3^2 - u2^2 = 0)>>;;

(* ------------------------------------------------------------------------- *)
(* Checking resultants (in one direction).                                   *)
(* ------------------------------------------------------------------------- *)

time grobner_decide
<<a * x^2 + b * x + c = 0 /\ 2 * a * x + b = 0
 ==> 4*a^2*c-b^2*a = 0>>;;

time grobner_decide
<<a * x^2 + b * x + c = 0 /\ d * x + e = 0
 ==> d^2*c-e*d*b+a*e^2 = 0>>;;

time grobner_decide
<<a * x^2 + b * x + c = 0 /\ d * x^2 + e * x + f = 0
 ==> d^2*c^2-2*d*c*a*f+a^2*f^2-e*d*b*c-e*b*a*f+a*e^2*c+f*d*b^2 = 0>>;;

*)
(* ========================================================================= *)
(* Real quantifier elimination (using Hormander's algorithm).                *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Evaluate a quantifier-free formula given a sign matrix row for its polys. *)
(* ------------------------------------------------------------------------- *)

(* Evaluate the quantifier-free formula fm to a boolean, given pmat, an      *)
(* association list from polynomials to signs. An atom comparing p with the  *)
(* term 0 is decided from the sign that assoc finds for p. Quantified        *)
(* formulas, and atoms of any other shape, cause a failure.                  *)
(* Uses the standard || and && instead of the deprecated or and &.           *)
let rec testform pmat fm =
  match fm with
    Atom(R(a,[p;Fn("0",[])])) ->
        (* The sign is looked up before the relation symbol is examined. *)
        let s = assoc p pmat in
        (match a with
           "=" -> s = Zero
         | "<=" -> s = Zero || s = Negative
         | ">=" -> s = Zero || s = Positive
         | "<" -> s = Negative
         | ">" -> s = Positive
         | _ -> failwith "testform: unknown literal")
  | False -> false
  | True -> true
  | Not(p) -> not(testform pmat p)
  | And(p,q) -> testform pmat p && testform pmat q
  | Or(p,q) -> testform pmat p || testform pmat q
  | Imp(p,q) -> not(testform pmat p) || testform pmat q
  | Iff(p,q) -> (testform pmat p = testform pmat q)
  | _ -> failwith "testform: non-propositional formula";;

(* ------------------------------------------------------------------------- *)
(* Infer sign of p(x) at points from corresponding qi(x) with pi(x) = 0      *)
(* ------------------------------------------------------------------------- *)

(* Prepend an inferred sign to the row pd. When pd contains Zero at some     *)
(* index, the entry of qd at that index is used; when the lookup fails, the  *)
(* sign Nonzero is used instead.                                             *)
let inferpsign pd qd =
  let inferred =
    try el (index Zero pd) qd
    with Failure _ -> Nonzero in
  inferred :: pd;;

(* ------------------------------------------------------------------------- *)
(* Condense subdivision by removing points with no relevant zeros.           *)
(* ------------------------------------------------------------------------- *)

(* Walk the rows two at a time and keep a pair only when its second row      *)
(* contains Zero. A trailing list of fewer than two rows is left untouched.  *)
let rec condense ps =
  match ps with
    interval::point::remaining ->
      let kept = condense remaining in
      if not (mem Zero point) then kept
      else interval::point::kept
  | _ -> ps;;

(* ------------------------------------------------------------------------- *)
(* Infer sign on intervals (use with infinities at end) and split if needed  *)
(* ------------------------------------------------------------------------- *)

(* Fill in the leading sign of each interval row from the leading signs s1   *)
(* and s2 of the two neighbouring point rows.                                *)
(*   - Opposite strict signs: the interval is split into three rows with a   *)
(*     Zero row in the middle.                                               *)
(*   - Equal strict signs: the interval takes that sign.                     *)
(*   - Exactly one of them Zero: the interval takes the other sign.          *)
(*   - Both Zero: fails with the inconsistent message, which trapout         *)
(*     intercepts.                                                           *)
(*   - Anything else: fails.                                                 *)
(* Lists with fewer than three rows are returned unchanged.                  *)
(* Uses the standard || and && instead of the deprecated or and &.           *)
let rec inferisign ps =
  match ps with
    pt1::int::pt2::other ->
      let res = inferisign(pt2::other)
      and tint = tl int and s1 = hd pt1 and s2 = hd pt2 in
      if s1 = Positive && s2 = Negative then
        pt1::(Positive::tint)::(Zero::tint)::(Negative::tint)::res
      else if s1 = Negative && s2 = Positive then
        pt1::(Negative::tint)::(Zero::tint)::(Positive::tint)::res
      else if (s1 = Positive || s2 = Negative) && s1 = s2 then
        pt1::(s1::tint)::res
      else if s1 = Zero && s2 = Zero then
        failwith "inferisign: inconsistent"
      else if s1 = Zero then
        pt1::(s2 :: tint)::res
      else if s2 = Zero then
        pt1::(s1 :: tint)::res
      else failwith "inferisign: can't infer sign on interval"
  | _ -> ps;;

(* ------------------------------------------------------------------------- *)
(* Deduce matrix for p,p1,...,pn from matrix for p',p1,...,pn,q0,...,qn      *)
(* where qi = rem(p,pi) with p0 = p'                                         *)
(* ------------------------------------------------------------------------- *)

(* Transform the sign matrix mat and hand the result to cont. Each row is    *)
(* cut in half; inferpsign combines the halves and condense prunes the       *)
(* result. Two extra rows built from the first and last inferred rows are    *)
(* added at the ends before inferisign runs and removed again afterwards.    *)
(* The second column is then dropped and the matrix condensed once more.     *)
let dedmatrix cont mat =
  let half = length (hd mat) / 2 in
  let firsts,seconds = unzip (map (fun row -> chop_list half row) mat) in
  let inferred = map2 inferpsign firsts seconds in
  let condensed = condense inferred in
  let width = length (hd condensed) in
  (* End rows: entry 1 of the last and first inferred rows, repeated. *)
  let closing_row = replicate width (el 1 (last inferred)) in
  let opening_row = replicate width (swap true (el 1 (hd inferred))) in
  let padded = (opening_row::condensed) @ [closing_row] in
  let interior = butlast (tl (inferisign padded)) in
  let drop_second row = hd row :: tl (tl row) in
  cont (condense (map drop_second interior));;

(* ------------------------------------------------------------------------- *)
(* Pseudo-division making sure the remainder has the same sign.              *)
(* ------------------------------------------------------------------------- *)

(* Run pdivide on q and p and adjust its second result r according to the    *)
(* sign s found for the head coefficient of p and the parity of the first    *)
(* result k. The value is r itself when s is Positive or k is even, the      *)
(* negation of r when s is Negative and k is odd, and r multiplied by the    *)
(* head coefficient otherwise. Fails when s is Zero.                         *)
(* Uses the standard || and && instead of the deprecated or and &.           *)
let pdivides vars sgns q p =
  let s = findsign vars sgns (head vars p) in
  if s = Zero then failwith "pdivides: head coefficient is zero" else
  let (k,r) = pdivide vars q p in
  if s = Negative && k mod 2 = 1 then poly_neg r
  else if s = Positive || k mod 2 = 0 then r
  else poly_mul (tl vars) (head vars p) r;;

(* ------------------------------------------------------------------------- *)
(* Case splitting for positive/negative (assumed nonzero).                   *)
(* ------------------------------------------------------------------------- *)

(* Continue with cont_p or cont_n according to the sign findsign reports for *)
(* pol. A sign of Zero is a failure. For any other sign the result is a      *)
(* disjunction: pol > 0 with cont_p run under the added assumption Positive, *)
(* or its negation with cont_n run under the added assumption Negative.      *)
let split_sign vars sgns pol cont_p cont_n =
  match findsign vars sgns pol with
    Positive -> cont_p sgns
  | Negative -> cont_n sgns
  | Zero -> failwith "split_sign: zero polynomial"
  | _ ->
      let ineq = Atom(R(">",[pol; Fn("0",[])])) in
      (* The negative branch is evaluated first, as in the constructor form. *)
      let neg_case = cont_n (assertsign vars sgns (pol,Negative)) in
      let pos_case = cont_p (assertsign vars sgns (pol,Positive)) in
      Or(And(ineq,pos_case),And(Not ineq,neg_case));;

(* ------------------------------------------------------------------------- *)
(* Formal derivative of polynomial.                                          *)
(* ------------------------------------------------------------------------- *)

(* Helper for poly_diff. While p has the shape c + x * q with x the first    *)
(* variable, c is multiplied by the numeral n and the recursion continues on *)
(* q with n + 1. Any other p is multiplied by the numeral n as a whole.      *)
let rec poly_diff_aux vars n p =
  let factor = mk_numeral (Int n) in
  match p with
    Fn("+",[c; Fn("*",[Var x; q])]) when x = hd vars ->
      let higher = poly_diff_aux vars (n+1) q in
      let scaled = poly_mul (tl vars) factor c in
      Fn("+",[scaled; Fn("*",[Var x; higher])])
  | _ -> poly_mul vars factor p;;

(* Formal derivative of p with respect to the first variable in vars. A term *)
(* not of the shape c + x * q, with x that first variable, yields the zero   *)
(* term.                                                                     *)
let poly_diff vars p =
  let zero = Fn("0",[]) in
  match p with
    Fn("+",[_; Fn("*",[Var x; q])]) ->
      if x = hd vars then poly_diff_aux vars 1 q else zero
  | _ -> zero;;

(* ------------------------------------------------------------------------- *)
(* Modify cont to insert constant sign into a sign matrix at position i.     *)
(* ------------------------------------------------------------------------- *)

(* Insert the sign s at position i of every row of mat, then apply cont.     *)
let matinsert i s cont mat =
  let widened = map (fun row -> insertat i s row) mat in
  cont widened;;

(* ------------------------------------------------------------------------- *)
(* Continuation will just return false if assignments are inconsistent.      *)
(* ------------------------------------------------------------------------- *)

(* Apply cont to m. The specific failure raised by inferisign for an         *)
(* inconsistent matrix is turned into the formula False; every other         *)
(* exception propagates unchanged.                                           *)
let trapout cont m =
  try cont m
  with Failure msg when msg = "inferisign: inconsistent" -> False;;

(* ------------------------------------------------------------------------- *)
(* Find matrix and apply continuation; split over coefficient zero and signs *)
(* ------------------------------------------------------------------------- *)

(* matrix vars pols cont sgns: compute a sign matrix for the polynomials     *)
(* pols and pass it to the continuation cont, under sign assumptions sgns.   *)
(*   - No polynomials: cont is applied, through trapout, to a matrix with a  *)
(*     single empty row.                                                     *)
(*   - Some polynomial satisfies is_constant vars: the first such p is       *)
(*     removed, the rest handled recursively, and matinsert puts the sign    *)
(*     reported by findsign back at position i of every row.                 *)
(*   - Otherwise: the first polynomial p of maximal degree is selected; qs   *)
(*     is pols with p removed and the derivative of p placed in front; gs    *)
(*     holds the results of pdivides for p against each element of qs.       *)
(*     Processing continues in splitzero with qs as the accumulated list     *)
(*     and gs still to examine. The continuation passes through dedmatrix    *)
(*     and then moves the head of each row back to position i.               *)
let rec matrix vars pols cont sgns =
  if pols = [] then trapout cont [[]] else
  if exists (is_constant vars) pols then
    let p = find (is_constant vars) pols in
    let i = index p pols in
    let pols1,pols2 = chop_list i pols in
    let pols' = pols1 @ tl pols2 in
    matrix vars pols' (matinsert i (findsign vars sgns p) cont) sgns
  else
    let d = itlist (max ** degree vars) pols (-1) in
    let p = find (fun p -> degree vars p = d) pols in
    let p' = poly_diff vars p and i = index p pols in
    let qs = let p1,p2 = chop_list i pols in p'::p1 @ tl p2 in
    let gs = map (pdivides vars sgns p) qs in
    let cont' m = cont(map (fun l -> insertat i (hd l) (tl l)) m) in
    splitzero vars qs gs (dedmatrix cont') sgns

(* splitzero vars dun pols cont sgns: examine pols one at a time, with dun   *)
(* the polynomials already accepted. A polynomial that is literally the zero *)
(* term is dropped and a Zero column inserted at position length dun.        *)
(* Otherwise split_zero is called on the head coefficient with two           *)
(* continuations: one retries with the head term removed by behead, the      *)
(* other appends the polynomial to dun. Presumably the first is taken when   *)
(* the coefficient is zero and the second when it is not; split_zero is      *)
(* defined elsewhere, so confirm there. Once pols is empty, splitsigns       *)
(* takes over with dun as the list to process.                               *)
and splitzero vars dun pols cont sgns =
  match pols with
    [] -> splitsigns vars [] dun cont sgns
  | p::ops -> if p = Fn("0",[]) then
                let cont' = matinsert (length dun) Zero cont in
                splitzero vars dun ops cont' sgns
              else split_zero (tl vars) sgns (head vars p)
                    (splitzero vars dun (behead vars p :: ops) cont)
                    (splitzero vars (dun@[p]) ops cont)

(* splitsigns vars dun pols cont sgns: for each polynomial, call split_sign  *)
(* on its head coefficient. The same continuation serves both signs, so the  *)
(* only difference between the branches is the sign assumption they carry.   *)
(* Once pols is empty, monicize processes the accumulated list.              *)
and splitsigns vars dun pols cont sgns =
  match pols with
    [] -> monicize vars dun cont sgns
  | p::ops -> let cont' = splitsigns vars (dun@[p]) ops cont in
              split_sign (tl vars) sgns (head vars p) cont' cont'

(* monicize vars pols cont sgns: apply monic to every polynomial, collecting *)
(* the first components in mols and the second in swaps. setify gives the    *)
(* list sols on which matrix is run. The wrapped continuation rebuilds each  *)
(* row for the original list: for every original polynomial it takes the     *)
(* column at that polynomial's index in sols and applies swap with the       *)
(* matching flag (presumably a sign flip when the flag is true; swap is      *)
(* defined elsewhere).                                                       *)
and monicize vars pols cont sgns =
  let mols,swaps = unzip(map (monic vars) pols) in
  let sols = setify mols in
  let indices = map (fun p -> index p sols) mols in
  let transform m =
    map2 (fun sw i -> swap sw (el i m)) swaps indices in
  let cont' mat = cont(map transform mat) in
  matrix vars sols cont' sgns;;

(* ------------------------------------------------------------------------- *)
(* Overall quelim for exists x. literal_1(x) /\ ... /\ literal_n(x)          *)
(* ------------------------------------------------------------------------- *)

(* Collect, via atom_union, the left-hand side p of every atom in fm whose   *)
(* right-hand side is the term 0. Atoms of any other shape add nothing.      *)
let polynomials fm =
  let lhs_of_atom = function
      R(_,[p;Fn("0",[])]) -> [p]
    | _ -> [] in
  atom_union lhs_of_atom fm;;

(* Eliminate the quantifier from a formula of the form Exists(x,bod). The    *)
(* final continuation returns True when some row of the sign matrix makes    *)
(* bod hold under testform, and False otherwise. The starting sign context   *)
(* records the term 1 as Positive.                                           *)
let basic_real_qelim vars fm =
  let Exists(x,bod) = fm in
  let pols = polynomials bod in
  let row_satisfies row = testform (zip pols row) bod in
  let cont mat = if exists row_satisfies mat then True else False in
  let initial_signs = [(Fn("1",[]),Positive)] in
  splitzero (x::vars) [] pols cont initial_signs;;

(* Top-level procedure: lift_qelim is instantiated once with polyatom, the   *)
(* normaliser simplify after evalc, and basic_real_qelim; its output is then *)
(* passed through evalc and simplify.                                        *)
let real_qelim =
  let normalise fm = simplify (evalc fm) in
  let eliminate = lift_qelim polyatom normalise basic_real_qelim in
  fun fm -> simplify (evalc (eliminate fm));;

(* ------------------------------------------------------------------------- *)
(* Sometimes it may pay to use DNF but we don't have to.                     *)
(* ------------------------------------------------------------------------- *)

(* Variant of real_qelim whose normaliser applies evalc, then cnnf with the  *)
(* identity function, then dnf.                                              *)
let real_qelim' =
  let to_cnnf = cnnf (fun x -> x) in
  let normalise fm = dnf (to_cnnf (evalc fm)) in
  let eliminate = lift_qelim polyatom normalise basic_real_qelim in
  fun fm -> simplify (evalc (eliminate fm));;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
real_qelim <<exists x. x^4 + x^2 + 1 = 0>>;;

real_qelim <<exists x. x^3 - x^2 + x - 1 = 0>>;;

real_qelim
 <<exists x y. x^3 - x^2 + x - 1 = 0 /\
              y^3 - y^2 + y - 1 = 0 /\ ~(x = y)>>;;

real_qelim
 <<forall a f k. (forall e. k < e ==> f < a * e) ==> f <= a * k>>;;

real_qelim
 <<exists x. a * x^2 + b * x + c = 0>>;;

real_qelim
 <<forall a b c. (exists x. a * x^2 + b * x + c = 0) <=>
                 b^2 >= 4 * a * c>>;;

real_qelim
 <<forall a b c. (exists x. a * x^2 + b * x + c = 0) <=>
                 a = 0 /\ (~(b = 0) \/ c = 0) \/
                 ~(a = 0) /\ b^2 >= 4 * a * c>>;;

(* ------------------------------------------------------------------------- *)
(* Termination ordering for group theory completion.                         *)
(* ------------------------------------------------------------------------- *)

real_qelim
 <<1 < 2 /\
   (forall x. 1 < x ==> 1 < x^2) /\
   (forall x y. 1 < x /\ 1 < y ==> 1 < x * (1 + 2 * y))>>;;
*)

(* Map a term built from variables, the constant 1, the unary function i and *)
(* the binary function * to an arithmetic term: a variable is unchanged, 1   *)
(* becomes 2, i(t) becomes t^2, and s * t becomes s * (1 + 2 * t), with the  *)
(* subterms translated recursively. Other terms are not matched.             *)
let rec trans tm =
  let two = Fn("2",[]) in
  match tm with
    Var _ -> tm
  | Fn("1",[]) -> two
  | Fn("i",[arg]) -> Fn("^",[trans arg; two])
  | Fn("*",[s;t]) ->
      let right = Fn("+",[Fn("1",[]); Fn("*",[two; trans t])]) in
      Fn("*",[trans s; right]);;

(* Turn an equation s = t into the inequality trans s > trans t. Only        *)
(* equality atoms are matched.                                               *)
let transeq fm =
  match fm with
    Atom(R("=",[lhs;rhs])) ->
      let rhs_image = trans rhs in
      let lhs_image = trans lhs in
      Atom(R(">",[lhs_image; rhs_image]));;

(* Wrap fm, for each of its free variables x, in a universal quantifier      *)
(* guarded by the hypothesis x > 1.                                          *)
let supergen fm =
  let guard x body =
    let above_one = Atom(R(">",[Var x; Fn("1",[])])) in
    Forall(x,Imp(above_one,body)) in
  itlist guard (fv fm) fm;;

(*
let eqs = complete_and_simplify ["1"; "*"; "i"]
  [<<1 * x = x>>; <<i(x) * x = 1>>; <<(x * y) * z = x * y * z>>];;

let fm = list_conj (map (supergen ** transeq) eqs);;

real_qelim fm;;
*)

(* ------------------------------------------------------------------------- *)
(* This one works better using DNF.                                          *)
(* ------------------------------------------------------------------------- *)

(*                                                            

real_qelim'
 <<forall d.
     (exists c. forall a b. (a = d /\ b = c) \/ (a = c /\ b = 1)
                            ==> a^2 = b)
     <=> d^4 = 1>>;;

(* ------------------------------------------------------------------------- *)
(* Linear examples.                                                          *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. x - 1 > 0>>;;

time real_qelim <<exists x. 3 - x > 0 /\ x - 1 > 0>>;;

(* ------------------------------------------------------------------------- *)
(* Quadratics.                                                               *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. x^2 = 0>>;;

time real_qelim <<exists x. x^2 + 1 = 0>>;;

time real_qelim <<exists x. x^2 - 1 = 0>>;;

time real_qelim <<exists x. x^2 - 2 * x + 1 = 0>>;;

time real_qelim <<exists x. x^2 - 3 * x + 1 = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Cubics.                                                                   *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. x^3 - 1 > 0>>;;

time real_qelim <<exists x. x^3 - 3 * x^2 + 3 * x - 1 > 0>>;;

time real_qelim <<exists x. x^3 - 4 * x^2 + 5 * x - 2 > 0>>;;

time real_qelim <<exists x. x^3 - 6 * x^2 + 11 * x - 6 = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Quartics.                                                                 *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. x^4 - 1 > 0>>;;

time real_qelim <<exists x. x^4 + 1 > 0>>;;

time real_qelim <<exists x. x^4 = 0>>;;

time real_qelim <<exists x. x^4 - x^3 = 0>>;;

time real_qelim <<exists x. x^4 - x^2 = 0>>;;

time real_qelim <<exists x. x^4 - 2 * x^2 + 2 = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Quintics.                                                                 *)
(* ------------------------------------------------------------------------- *)

time real_qelim
  <<exists x. x^5 - 15 * x^4 + 85 * x^3 - 225 * x^2 + 274 * x - 120 = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Sextics(?)                                                                *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x.
 x^6 - 21 * x^5 + 175 * x^4 - 735 * x^3 + 1624 * x^2 - 1764 * x + 720 = 0>>;;

time real_qelim <<exists x.
 x^6 - 12 * x^5 + 56 * x^4 - 130 * x^3 + 159 * x^2 - 98 * x + 24 = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Multiple polynomials.                                                     *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. x^2 + 2 > 0 /\ x^3 - 11 = 0 /\ x + 131 >= 0>>;;

(* ------------------------------------------------------------------------- *)
(* With more variables.                                                      *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. a * x^2 + b * x + c = 0>>;;

time real_qelim <<exists x. a * x^3 + b * x^2 + c * x + d = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Constraint solving.                                                       *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x1 x2. x1^2 + x2^2 - u1 <= 0 /\ x1^2 - u2 > 0>>;;

(* ------------------------------------------------------------------------- *)
(* Huet & Oppen (interpretation of group theory).                            *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<forall x y. x > 0 /\ y > 0 ==> x * (1 + 2 * y) > 0>>;;

(* ------------------------------------------------------------------------- *)
(* Other examples.                                                           *)
(* ------------------------------------------------------------------------- *)

time real_qelim
  <<forall a f k. (forall e. k < e ==> f < a * e) ==> f <= a * k>>;;

time real_qelim <<exists x. x^2 - x + 1 = 0>>;;

time real_qelim <<exists x. x^2 - 3 * x + 1 = 0>>;;

time real_qelim <<exists x. x > 6 /\ (x^2 - 3 * x + 1 = 0)>>;;

time real_qelim <<exists x. 7 * x^2 - 5 * x + 3 > 0 /\
                            x^2 - 3 * x + 1 = 0>>;;

time real_qelim <<exists x. 11 * x^3 - 7 * x^2 - 2 * x + 1 = 0 /\
                            7 * x^2 - 5 * x + 3 > 0 /\
                            x^2 - 8 * x + 1 = 0>>;;

time real_qelim <<exists x. a * x^2 + b * x + c = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Quadratic inequality from Liska and Steinberg                             *)
(* ------------------------------------------------------------------------- *)

time real_qelim
 <<forall x. -(1) <= x /\ x <= 1 ==>
      C * (x - 1) * (4 * x * a * C - x * C - 4 * a * C + C - 2) >= 0>>;;


(* ------------------------------------------------------------------------- *)
(* Metal-milling example from Loos and Weispfenning                          *)
(* ------------------------------------------------------------------------- *)

time real_qelim
  <<exists x y. 0 < x /\
                y < 0 /\
                x * r - x * t + t = q * x - s * x + s /\
                x * b - x * d + d = a * y - c * y + c>>;;


(* ------------------------------------------------------------------------- *)
(* Linear example from Collins and Johnson                                   *)
(* ------------------------------------------------------------------------- *)

time real_qelim
 <<exists r. 0 < r /\
      r < 1 /\
      0 < (1 - 3 * r) * (a^2 + b^2) + 2 * a * r /\
      (2 - 3 * r) * (a^2 + b^2) + 4 * a * r - 2 * a - r < 0>>;;


(* ------------------------------------------------------------------------- *)
(* Dave Griffioen #4                                                         *)
(* ------------------------------------------------------------------------- *)

time real_qelim
 <<forall x y. (1 - t) * x <= (1 + t) * y /\ (1 - t) * y <= (1 + t) * x
         ==> 0 <= y>>;;

(* ------------------------------------------------------------------------- *)
(* Some examples from "Real Quantifier Elimination in practice".             *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x2. x1^2 + x2^2 <= u1 /\ x1^2 > u2>>;;

time real_qelim <<exists x1 x2. x1^2 + x2^2 <= u1 /\ x1^2 > u2>>;;

time real_qelim
 <<forall x1 x2. x1 + x2 <= 2 /\ x1 <= 1 /\ x1 >= 0 /\ x2 >= 0
           ==> 3 * (x1 + 3 * x2^2 + 2) <= 8 * (2 * x1 + x2 + 1)>>;;

(* ------------------------------------------------------------------------- *)
(* From Collins & Johnson's "Sign variation..." article.                     *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists r. 0 < r /\ r < 1 /\
                (1 - 3 * r) * (a^2 + b^2) + 2 * a * r > 0 /\
                (2 - 3 * r) * (a^2 + b^2) + 4 * a * r - 2 * a - r < 0>>;;

(* ------------------------------------------------------------------------- *)
(* From "Parallel implementation of CAD" article.                            *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. forall y. x^2 + y^2 > 1 /\ x * y >= 1>>;;

(* ------------------------------------------------------------------------- *)
(* Other misc examples.                                                      *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<forall x y. x^2 + y^2 = 1 ==> 2 * x * y <= 1>>;;

time real_qelim <<forall x y. x^2 + y^2 = 1 ==> 2 * x * y < 1>>;;

time real_qelim <<forall x y. x * y > 0 <=> x > 0 /\ y > 0 \/ x < 0 /\ y < 0>>;;

time real_qelim <<exists x y. x > y /\ x^2 < y^2>>;;

time real_qelim <<forall x y. x < y ==> exists z. x < z /\ z < y>>;;

time real_qelim <<forall x. 0 < x <=> exists y. x * y^2 = 1>>;;

time real_qelim <<forall x. 0 <= x <=> exists y. x * y^2 = 1>>;;

time real_qelim <<forall x. 0 <= x <=> exists y. x = y^2>>;;

time real_qelim <<forall x y. 0 < x /\ x < y ==> exists z. x < z^2 /\ z^2 < y>>;;

time real_qelim <<forall x y. x < y ==> exists z. x < z^2 /\ z^2 < y>>;;

time real_qelim <<forall x y. x^2 + y^2 = 0 ==> x = 0 /\ y = 0>>;;

time real_qelim <<forall x y z. x^2 + y^2 + z^2 = 0 ==> x = 0 /\ y = 0 /\ z = 0>>;;

time real_qelim <<forall w x y z. w^2 + x^2 + y^2 + z^2 = 0
                      ==> w = 0 /\ x = 0 /\ y = 0 /\ z = 0>>;;

time real_qelim <<forall a. a^2 = 2 ==> forall x. ~(x^2 + a*x + 1 = 0)>>;;

time real_qelim <<forall a. a^2 = 2 ==> forall x. ~(x^2 - a*x + 1 = 0)>>;;

time real_qelim <<forall x y. x^2 = 2 /\ y^2 = 3 ==> (x * y)^2 = 6>>;;

time real_qelim <<forall x. exists y. x^2 = y^3>>;;

time real_qelim <<forall x. exists y. x^3 = y^2>>;;

time real_qelim
 <<forall a b c.
        (a * x^2 + b * x + c = 0) /\
        (a * y^2 + b * y + c = 0) /\
        ~(x = y)
        ==> (a * (x + y) + b = 0)>>;;

time real_qelim
 <<forall y_1 y_2 y_3 y_4.
     (y_1 = 2 * y_3) /\
     (y_2 = 2 * y_4) /\
     (y_1 * y_3 = y_2 * y_4)
     ==> (y_1^2 = y_2^2)>>;;

time real_qelim <<forall x. x^2 < 1 <=> x^4 < 1>>;;

(* ------------------------------------------------------------------------- *)
(* Counting roots.                                                           *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<exists x. x^3 - x^2 + x - 1 = 0>>;;

time real_qelim
  <<exists x y. x^3 - x^2 + x - 1 = 0 /\ y^3 - y^2 + y - 1 = 0 /\ ~(x = y)>>;;

time real_qelim <<exists x. x^4 + x^2 - 2 = 0>>;;

time real_qelim
  <<exists x y. x^4 + x^2 - 2 = 0 /\ y^4 + y^2 - 2 = 0 /\ ~(x = y)>>;;

time real_qelim
  <<exists x y. x^3 + x^2 - x - 1 = 0 /\ y^3 + y^2 - y - 1 = 0 /\ ~(x = y)>>;;

time real_qelim <<exists x y z. x^3 + x^2 - x - 1 = 0 /\
                    y^3 + y^2 - y - 1 = 0 /\
                    z^3 + z^2 - z - 1 = 0 /\ ~(x = y) /\ ~(x = z)>>;;

(* ------------------------------------------------------------------------- *)
(* Existence of tangents, so to speak.                                       *)
(* ------------------------------------------------------------------------- *)

time real_qelim
  <<forall x y. exists s c. s^2 + c^2 = 1 /\ s * x + c * y = 0>>;;

(* ------------------------------------------------------------------------- *)
(* Another useful thing (componentwise ==> normwise accuracy etc.)           *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<forall x y. (x + y)^2 <= 2 * (x^2 + y^2)>>;;

(* ------------------------------------------------------------------------- *)
(* Some related quantifier elimination problems.                             *)
(* ------------------------------------------------------------------------- *)

time real_qelim <<forall x y. (x + y)^2 <= c * (x^2 + y^2)>>;;

time real_qelim
  <<forall c. (forall x y. (x + y)^2 <= c * (x^2 + y^2)) <=> 2 <= c>>;;

time real_qelim <<forall a b. a * b * c <= a^2 + b^2>>;;

time real_qelim
  <<forall c. (forall a b. a * b * c <= a^2 + b^2) <=> c^2 <= 4>>;;

(* ------------------------------------------------------------------------- *)
(* Tedious lemmas I once proved manually in HOL.                             *)
(* ------------------------------------------------------------------------- *)

time real_qelim
 <<forall a b c. 0 < a /\ 0 < b /\ 0 < c
                 ==> 0 < a * b /\ 0 < a * c /\ 0 < b * c>>;;

time real_qelim
  <<forall a b c. a * b > 0 ==> (c * a < 0 <=> c * b < 0)>>;;

time real_qelim
  <<forall a b c. a * b > 0 ==> (a * c < 0 <=> b * c < 0)>>;;

time real_qelim
  <<forall a b. a < 0 ==> (a * b > 0 <=> b < 0)>>;;

time real_qelim
  <<forall a b c. a * b < 0 /\ ~(c = 0) ==> (c * a < 0 <=> ~(c * b < 0))>>;;

time real_qelim
  <<forall a b. a * b < 0 <=> a > 0 /\ b < 0 \/ a < 0 /\ b > 0>>;;

time real_qelim
  <<forall a b. a * b <= 0 <=> a >= 0 /\ b <= 0 \/ a <= 0 /\ b >= 0>>;;

(* ------------------------------------------------------------------------- *)
(* Vaguely connected with reductions for Robinson arithmetic.                *)
(* ------------------------------------------------------------------------- *)

time real_qelim
  <<forall a b. ~(a <= b) <=> forall d. d <= b ==> d < a>>;;

time real_qelim
  <<forall a b. ~(a <= b) <=> forall d. d <= b ==> ~(d = a)>>;;

time real_qelim
  <<forall a b. ~(a < b) <=> forall d. d < b ==> d < a>>;;

(* ------------------------------------------------------------------------- *)
(* Another nice problem.                                                     *)
(* ------------------------------------------------------------------------- *)

time real_qelim
 <<forall x y. x^2 + y^2 = 1 ==> (x + y)^2 <= 2>>;;

(* ------------------------------------------------------------------------- *)
(* Some variants / intermediate steps in Cauchy-Schwarz inequality.          *)
(* ------------------------------------------------------------------------- *)

time real_qelim
 <<forall x y. 2 * x * y <= x^2 + y^2>>;;

time real_qelim
 <<forall a b c d. 2 * a * b * c * d <= a^2 * b^2 + c^2 * d^2>>;;

time real_qelim
 <<forall x1 x2 y1 y2.
      (x1 * y1 + x2 * y2)^2 <= (x1^2 + x2^2) * (y1^2 + y2^2)>>;;

(* ------------------------------------------------------------------------- *)
(* The determinant example works OK here too.                                *)
(* ------------------------------------------------------------------------- *)

time real_qelim
 <<exists w x y z. (a * w + b * y = 1) /\
                   (a * x + b * z = 0) /\
                   (c * w + d * y = 0) /\
                   (c * x + d * z = 1)>>;;

time real_qelim
 <<forall a b c d.
        (exists w x y z. (a * w + b * y = 1) /\
                         (a * x + b * z = 0) /\
                         (c * w + d * y = 0) /\
                         (c * x + d * z = 1))
        <=> ~(a * d = b * c)>>;;

*)
(* ========================================================================= *)
(* Geometry theorem proving.                                                 *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* List of geometric properties with their coordinate translations.          *)
(* ------------------------------------------------------------------------- *)

(* Association list from a geometric predicate name to its algebraic         *)
(* translation. In each pattern the variables n_x and n_y stand for the      *)
(* coordinates of the predicate's n'th argument point; inst_coord later      *)
(* replaces them with the coordinates of the actual points.                  *)
let coordinations =
  ["collinear", (* points 1, 2 and 3 lie on a common line *)
   <<(1_x - 2_x) * (2_y - 3_y) = (1_y - 2_y) * (2_x - 3_x)>>;
   "parallel", (* line (1,2) is parallel to line (3,4) *)
    <<(1_x - 2_x) * (3_y - 4_y) = (1_y - 2_y) * (3_x - 4_x)>>;
   "perpendicular", (* line (1,2) is perpendicular to line (3,4) *)
   <<(1_x - 2_x) * (3_x - 4_x) + (1_y - 2_y) * (3_y - 4_y) = 0>>;
   "lengths_eq", (* segments (1,2) and (3,4) have equal squared length *)
   <<(1_x - 2_x)^2 + (1_y - 2_y)^2 = (3_x - 4_x)^2 + (3_y - 4_y)^2>>;
   "is_midpoint", (* point 1 is the midpoint of segment (2,3) *)
   <<2 * 1_x = 2_x + 3_x /\ 2 * 1_y = 2_y + 3_y>>;
   "is_intersection", (* point 1 lies on both line (2,3) and line (4,5) *)
   <<(1_x - 2_x) * (2_y - 3_y) = (1_y - 2_y) * (2_x - 3_x) /\
     (1_x - 4_x) * (4_y - 5_y) = (1_y - 4_y) * (4_x - 5_x)>>;
   "=",<<(1_x = 2_x) /\ (1_y = 2_y)>>; (* points 1 and 2 coincide *)
   "angles_eq", (* angle 1-2-3 vs angle 4-5-6: cross/dot products at vertex 2 *)
                (* and vertex 5, cross-multiplied so no division is needed    *)
   <<((2_y - 1_y) * (2_x - 3_x) - (2_y - 3_y) * (2_x - 1_x)) *
     ((5_x - 4_x) * (5_x - 6_x) + (5_y - 4_y) * (5_y - 6_y)) =
     ((5_y - 4_y) * (5_x - 6_x) - (5_y - 6_y) * (5_x - 4_x)) *
     ((2_x - 1_x) * (2_x - 3_x) + (2_y - 1_y) * (2_y - 3_y))>>];;

(* ------------------------------------------------------------------------- *)
(* Convert formula into coordinate form.                                     *)
(* ------------------------------------------------------------------------- *)

(* Instantiate a coordinate pattern for a concrete list of points.           *)
(*   fms : the argument terms of a geometric atom; each must be a variable   *)
(*         (a non-variable argument raises Match_failure).                   *)
(*   pat : pattern formula over the placeholders 1_x, 1_y, 2_x, 2_y, ...     *)
(* The i'th placeholder pair is replaced by v_x, v_y for the i'th point v.   *)
let inst_coord fms pat =
  let indices = 1--length fms in
  let slot suffix n = string_of_int n ^ suffix
  and coord suffix (Var v) = Var(v ^ suffix) in
  let point_terms = map (coord "_x") fms @ map (coord "_y") fms in
  let placeholders = map (slot "_x") indices @ map (slot "_y") indices in
  formsubst (instantiate placeholders point_terms) pat;;

(* Replace every geometric atom of fm by its coordinate translation, looked  *)
(* up by predicate name in the coordinations table.                          *)
let coordinate fm =
  let translate_atom (R(pred,points)) =
    let pattern = assoc pred coordinations in
    inst_coord points pattern in
  onatoms translate_atom fm;;

(* ------------------------------------------------------------------------- *)
(* Trivial example.                                                          *)
(* ------------------------------------------------------------------------- *)

(*
coordinate <<collinear(a,b,c) ==> collinear(b,a,c)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Verify equivalence under rotation.                                        *)
(* ------------------------------------------------------------------------- *)

(* Check that a coordinate pattern is invariant under rotation.              *)
(* Takes an entry (name,fm) of the coordinations table; the name is ignored. *)
(* Every placeholder pair (n_x,n_y), n = 1..6, is replaced by                *)
(*   (c * n_x - s * n_y, c * n_y + s * n_x)                                  *)
(* and grobner_decide is asked whether  s^2 + c^2 = 1 ==> (fm' <=> fm).      *)
let test_invariance(_,fm) =
  let s = <<|s|>> and c = <<|c|>> in
  let rotate n body =
    let x = n^"_x" and y = n^"_y" in
    let x' = Fn("-",[Fn("*",[c; Var x]); Fn("*",[s; Var y])])
    and y' = Fn("+",[Fn("*",[c; Var y]); Fn("*",[s; Var x])]) in
    formsubst (instantiate [x;y] [x';y']) body in
  let rotated = itlist rotate (map string_of_int (1--6)) fm in
  let unit_circle = <<s^2 + c^2 = 1>> in
  grobner_decide (Imp(unit_circle,Iff(rotated,fm)));;

(*
forall test_invariance coordinations;;
*)

(* ------------------------------------------------------------------------- *)
(* And show we can always invent such a transformation to zero a y:          *)
(* ------------------------------------------------------------------------- *)

(*
real_qelim
 <<forall x y. exists s c. s^2 + c^2 = 1 /\ s * x + c * y = 0>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Choose one point to be the origin and rotate to zero another y coordinate *)
(* ------------------------------------------------------------------------- *)

(* Coordinatize fm with a convenient choice of axes: the first free variable *)
(* (point) a of fm is placed at the origin (a_x = a_y = 0) and the second    *)
(* point b has its y coordinate set to zero (b_y = 0).                       *)
(* Fails with Failure if fm mentions fewer than two points, since there is   *)
(* then nothing to fix the axes with.                                        *)
let originate fm =
  match fv fm with
    a::b::_ ->
      let rfn = itlist (fun v -> v |-> Fn("0",[]))
                       [a^"_x"; a^"_y"; b^"_y"] undefined in
      formsubst rfn (coordinate fm)
  | _ -> failwith "originate: formula must mention at least two points";;

(* ------------------------------------------------------------------------- *)
(* Invariance under shearing, hence any affine xform, for many properties.   *)
(* ------------------------------------------------------------------------- *)

(* Check that a coordinate pattern is invariant under shearing.              *)
(* Takes an entry (name,fm) of the coordinations table; the name is ignored. *)
(* For n = 1..6 the placeholder n_x is replaced by 1 * n_x + b * n_y (n_y is *)
(* left alone), and grobner_decide is asked whether  fm' <=> fm.             *)
let test_str_invariance(_,fm) =
  let b = <<|b|>> in
  let shear n =
    let x = n^"_x" and y = n^"_y" in
    formsubst
      (x := Fn("+",[Fn("*",[Fn("1",[]); Var x]); Fn("*",[b; Var y])])) in
  let fm' = itlist shear (map string_of_int (1--6)) fm in
  grobner_decide (Iff(fm',fm));;

(*
map (fun a -> fst a,test_str_invariance a) (butlast coordinations);;
*)

(* ------------------------------------------------------------------------- *)
(* Examples of inadequacy but fixability of complex coordinates.             *)
(* ------------------------------------------------------------------------- *)

(*
(grobner_decide ** originate)
 <<lengths_eq(A,X,B,X) /\ lengths_eq(B,X,C,X) /\
   lengths_eq(A,Y,B,Y) /\ lengths_eq(B,Y,C,Y) /\
   ~(A = B) /\ ~(A = C) /\ ~(B = C) ==> X = Y>>;;

(* ------------------------------------------------------------------------- *)
(* Centroid (Chou, example 142).                                             *)
(* ------------------------------------------------------------------------- *)

(grobner_decide ** originate)
 <<is_midpoint(d,b,c) /\
   is_midpoint(e,a,c) /\
   is_midpoint(f,a,b) /\
   is_intersection(m,b,e,a,d)
   ==> collinear(c,f,m)>>;;

(* ------------------------------------------------------------------------- *)
(* One from "Algorithms for Computer Algebra"                                *)
(* ------------------------------------------------------------------------- *)

(grobner_decide ** originate)
 <<is_midpoint(m,a,c) /\ perpendicular(a,c,m,b)
   ==> lengths_eq(a,b,b,c)>>;;

(* ------------------------------------------------------------------------- *)
(* Parallelogram theorem (Chou's expository example at the start).           *)
(* ------------------------------------------------------------------------- *)

(grobner_decide ** originate)
 <<parallel(a,b,d,c) /\ parallel(a,d,b,c) /\ is_intersection(e,a,c,b,d)
   ==> lengths_eq(a,e,e,c)>>;;

(grobner_decide ** originate)
 <<parallel(a,b,d,c) /\ parallel(a,d,b,c) /\
   is_intersection(e,a,c,b,d) /\ ~collinear(a,b,c)
   ==> lengths_eq(a,e,e,c)>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Reduce p using triangular set, collecting degenerate conditions.          *)
(* ------------------------------------------------------------------------- *)

(* Reduce the polynomial p against the triangular set triang, returning      *)
(* degens extended with the conditions collected on the way.                 *)
(*   vars   : variable ordering; its head is the current main variable       *)
(*   triang : triangular set; members are expected in the canonical shape    *)
(*            c + x * q' with x the member's main variable                   *)
(*   p      : polynomial to reduce (canonical form, so zero is Fn("0",[]))   *)
(*   degens : conditions accumulated so far                                  *)
(* When triang is exhausted and p is not zero, the residual goal p = 0 is    *)
(* added. When pseudo-division needs a nonzero multiplier count k, the       *)
(* non-degeneracy condition ~(head vars q = 0) is recorded.                  *)
(* NOTE(review): the final case handles a member not of the shape            *)
(* c + x * q' by recording ~(False \/ q = 0) and stopping; confirm that      *)
(* this is the intended treatment of such members.                           *)
let rec pprove vars triang p degens =
  let zero = Fn("0",[]) in
  let eq_zero t = Atom(R("=",[t; zero])) in
  if p = zero then degens else
  match triang with
    [] -> eq_zero p :: degens
  | (Fn("+",[_;Fn("*",[Var x;_])]) as q)::rest ->
      let top = hd vars in
      if x = top then
        let k,rem = pdivide vars p q in
        if k = 0 then pprove vars rest rem degens else
        let degens' = Not(eq_zero (head vars q))::degens in
        if is_constant vars rem then pprove vars rest rem degens'
        else itlist (pprove vars rest) (coefficients vars rem) degens'
      else if mem top (fvt p) then
        (* p involves a variable ahead of q's: reduce its coefficients *)
        itlist (pprove vars triang) (coefficients vars p) degens
      else pprove (tl vars) triang p degens
  | q::_ -> Not(Or(False,eq_zero q))::degens;;

(* ------------------------------------------------------------------------- *)
(* Triangulate a set of polynomials.                                         *)
(* ------------------------------------------------------------------------- *)

(* Turn a list of polynomials into a triangular set w.r.t. the variable      *)
(* ordering vars.                                                            *)
(*   vars   : remaining variables, current main variable first               *)
(*   consts : polynomials already found constant in the current variable;    *)
(*            they are dealt with once we move on to the next variable       *)
(*   pols   : polynomials still to be processed for the current variable     *)
(* At each step a polynomial of least degree in the current variable is      *)
(* chosen as pivot. If it is linear, all others are pseudo-divided by it and *)
(* the remainders are passed down to the next variable; otherwise the next   *)
(* lowest-degree polynomial is replaced by its remainder modulo the pivot    *)
(* and the process repeats on the same variable.                             *)
let rec triangulate vars consts pols =
  let lowest polys =
    let d = end_itlist min (map (degree vars) polys) in
    d,find (fun r -> degree vars r = d) polys in
  match vars,pols with
    [],_ -> pols
  | _::later,[] -> triangulate later [] consts
  | _::later,_ ->
      let cns,tpols = partition (is_constant vars) pols in
      if cns <> [] then triangulate vars (cns @ consts) tpols
      else if length pols = 1 then pols @ triangulate later [] consts
      else
        let n,pivot = lowest pols in
        let others = subtract pols [pivot] in
        if n = 1 then
          let rems = map (fun q -> snd(pdivide vars q pivot)) others in
          pivot :: triangulate later [] (consts @ rems)
        else
          let _,q = lowest others in
          let untouched = subtract others [q] in
          let reduced = snd(pdivide vars q pivot) in
          triangulate vars consts (pivot::reduced::untouched);;

(* ------------------------------------------------------------------------- *)
(* Auxiliary stuff.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Split an implication p ==> q into the pair (p,q).                         *)
(* Fails with Failure "dest_imp" on any other formula.                       *)
let dest_imp = function
    Imp(ante,conc) -> ante,conc
  | _ -> failwith "dest_imp";;

(* Left- and right-hand sides of an equation, as split by dest_eq.           *)
let lhs eq = let l,_ = dest_eq eq in l
and rhs eq = let _,r = dest_eq eq in r;;

(* ------------------------------------------------------------------------- *)
(* Trivial version of Wu's method based on repeated pseudo-division.         *)
(* ------------------------------------------------------------------------- *)

(* Simple form of Wu's method.                                               *)
(*   fm    : geometric statement of the form  hypotheses ==> conclusion      *)
(*   vars  : coordinate variables in elimination order                       *)
(*   zeros : coordinate variables to be set to zero (choice of axes)         *)
(* The statement is coordinatized, the zeros are substituted, and the result *)
(* must have exactly vars as its free variables, otherwise the call fails    *)
(* with "wu: wrong variable set". The hypotheses are triangulated and each   *)
(* conclusion polynomial is reduced against the triangular set; the union of *)
(* all conditions collected by pprove is returned.                           *)
let wu fm vars zeros =
  let gfm0 = coordinate fm in
  let zero_out = itlist (fun v -> v |-> Fn("0",[])) zeros undefined in
  let gfm = formsubst zero_out gfm0 in
  if set_eq vars (fv gfm) then
    let ant,con = dest_imp gfm in
    let to_pols part =
      map (fun a -> lhs (polyatom vars a)) (conjuncts part) in
    let pols = to_pols ant
    and ps = to_pols con in
    let tri = triangulate vars [] pols in
    itlist (fun goal acc -> union (pprove vars tri goal []) acc) ps []
  else failwith "wu: wrong variable set";;

(* ------------------------------------------------------------------------- *)
(* Simson's theorem.                                                         *)
(* ------------------------------------------------------------------------- *)

(*
let simson =
 <<lengths_eq(o,a,o,b) /\
   lengths_eq(o,a,o,c) /\
   lengths_eq(o,a,o,d) /\
   collinear(e,b,c) /\
   collinear(f,a,c) /\
   collinear(g,a,b) /\
   perpendicular(b,c,d,e) /\
   perpendicular(a,c,d,f) /\
   perpendicular(a,b,d,g)
   ==> collinear(e,f,g)>>;;

let vars =
 ["g_y"; "g_x"; "f_y"; "f_x"; "e_y"; "e_x"; "d_y"; "d_x"; "c_y"; "c_x";
  "b_y"; "b_x"; "o_x"]
and zeros = ["a_x"; "a_y"; "o_y"];;

wu simson vars zeros;;

(* ------------------------------------------------------------------------- *)
(* Try without special coordinates.                                          *)
(* ------------------------------------------------------------------------- *)

wu simson (vars @ zeros) [];;

(* ------------------------------------------------------------------------- *)
(* Pappus (Chou's figure 6).                                                 *)
(* ------------------------------------------------------------------------- *)

let pappus =
 <<collinear(a1,b2,d) /\
   collinear(a2,b1,d) /\
   collinear(a2,b3,e) /\
   collinear(a3,b2,e) /\
   collinear(a1,b3,f) /\
   collinear(a3,b1,f)
   ==> collinear(d,e,f)>>;;

let vars = ["f_y"; "f_x"; "e_y"; "e_x"; "d_y"; "d_x";
            "b3_y"; "b2_y"; "b1_y"; "a3_x"; "a2_x"; "a1_x"]
and zeros = ["a1_y"; "a2_y"; "a3_y"; "b1_x"; "b2_x"; "b3_x"];;

wu pappus vars zeros;;

(* ------------------------------------------------------------------------- *)
(* Without special coordinates.                                              *)
(* ------------------------------------------------------------------------- *)

let pappus =
 <<collinear(a1,a2,a3) /\
   collinear(b1,b2,b3) /\
   collinear(a1,b2,d) /\
   collinear(a2,b1,d) /\
   collinear(a2,b3,e) /\
   collinear(a3,b2,e) /\
   collinear(a1,b3,f) /\
   collinear(a3,b1,f)
   ==> collinear(d,e,f)>>;;

wu pappus (vars @ zeros) [];;

*)
(* ========================================================================= *)
(* Implementation/proof of the Craig-Robinson interpolation theorem.         *)
(*                                                                           *)
(* This is based on the proof in Kreisel & Krivine, which works very nicely  *)
(* in our context.                                                           *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Interpolation for propositional logic.                                    *)
(* ------------------------------------------------------------------------- *)

(* Eliminate the propositional atoms pvs from fm one at a time, replacing    *)
(* the current formula r by  r[False/p] \/ r[True/p]  for each p in turn     *)
(* (first element of pvs is eliminated first).                               *)
let orify pvs fm =
  let split r p =
    Or(propsubst (p := False) r,propsubst (p := True) r) in
  List.fold_left split fm pvs;;

(* Propositional interpolant of p and q: the atoms occurring in p but not in *)
(* q are eliminated from p by orify and the result is simplified.            *)
let pinterpolate p q =
  let private_atoms = subtract (atoms p) (atoms q) in
  psimplify (orify private_atoms p);;

(* ------------------------------------------------------------------------- *)
(* Relation-symbol interpolation for universal closed formulas.              *)
(* ------------------------------------------------------------------------- *)

(* Relation-symbol interpolation for universal closed formulas p and q.      *)
(* The conjunction p /\ q is prenexed and stripped of its universal          *)
(* quantifiers; dp_refine_loop then searches for a set of ground instance    *)
(* tuples (presumably a refuting set, which requires p /\ q to be            *)
(* unsatisfiable -- confirm against dp_refine_loop). Each tuple instantiates *)
(* the matrix, which is split back into its p-part and q-part, and the       *)
(* propositional interpolant of the two resulting conjunctions is returned.  *)
(* Only one instance search is run: an earlier extra dp_loop call with the   *)
(* same arguments computed a result that was never used.                     *)
let urinterpolate p q =
  let fm = specialize(prenex(And(p,q))) in
  let fvs = fv fm and consts,funcs = herbfuns fm in
  let cntms = map (fun (c,_) -> Fn(c,[])) consts in
  let tups = dp_refine_loop (clausal fm) cntms funcs fvs 0 [] [] [] in
  let fmis = map (fun tup -> formsubst (instantiate fvs tup) fm) tups in
  (* every instance is still a conjunction: split it into the two sides *)
  let ps,qs = unzip (map (fun (And(p,q)) -> p,q) fmis) in
  pinterpolate (list_conj(setify ps)) (list_conj(setify qs));;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let p = prenex
 <<(forall x. R(x,f(x))) /\ (forall x y. S(x,y) <=> R(x,y) \/ R(y,x))>>
and q = prenex
 <<(forall x y z. S(x,y) /\ S(y,z) ==> T(x,z)) /\ ~T(0,0)>>;;

let c = urinterpolate p q;;

meson(Imp(p,c));;
meson(Imp(q,Not c));;
*)

(* ------------------------------------------------------------------------- *)
(* Pick the topmost terms starting with one of the given function symbols.   *)
(* ------------------------------------------------------------------------- *)

(* Collect the topmost subterms of a term whose head (name,arity) is in fns. *)
(* The search does not descend below a matching subterm; results are         *)
(* combined with union.                                                      *)
let rec toptermt fns = function
    Var _ -> []
  | Fn(f,args) as tm ->
      if mem (f,length args) fns then [tm]
      else itlist (fun arg acc -> union (toptermt fns arg) acc) args [];;

(* Lift toptermt to formulas: union of the topmost fns-headed terms over the *)
(* arguments of every atom.                                                  *)
let topterms fns =
  let of_atom (R(_,args)) =
    itlist (fun arg acc -> union (toptermt fns arg) acc) args [] in
  atom_union of_atom;;

(* ------------------------------------------------------------------------- *)
(* Interpolation for arbitrary universal formulas.                           *)
(* ------------------------------------------------------------------------- *)

(* Interpolation for arbitrary universal formulas.                           *)
(* Starts from the relation-symbol interpolant given by urinterpolate and    *)
(* removes terms headed by function symbols that are not shared by p and q.  *)
(* Such topmost terms are processed in the order given by                    *)
(* sort (decreasing termsize); the k'th one is replaced by a fresh variable  *)
(* v_k, quantified existentially if its head symbol occurs in p and          *)
(* universally otherwise.                                                    *)
let uinterpolate p q =
  let fp = functions p and fq = functions q in
  let abstract (n,body) (Fn(f,args) as tm) =
    let v = "v_"^(string_of_int n) in
    let body' = replace (tm:=Var v) body in
    let quantified =
      if mem (f,length args) fp then Exists(v,body') else Forall(v,body') in
    n+1,quantified in
  let c = urinterpolate p q in
  let unshared = union (subtract fp fq) (subtract fq fp) in
  let tms = sort (decreasing termsize) (topterms unshared c) in
  snd (List.fold_left abstract (1,c) tms);;

(* ------------------------------------------------------------------------- *)
(* The same example now gives a true interpolant.                            *)
(* ------------------------------------------------------------------------- *)

(*
let c = uinterpolate p q;;

meson(Imp(p,c));;
meson(Imp(q,Not c));;
*)

(* ------------------------------------------------------------------------- *)
(* Now lift to arbitrary formulas with no common free variables.             *)
(* ------------------------------------------------------------------------- *)

(* Interpolation for arbitrary formulas with no common free variables.       *)
(* The NNF of p /\ q is existentially closed over its free variables and     *)
(* Skolemized; the two conjuncts of the result are then handed to            *)
(* uinterpolate. A Skolemized result that is not a conjunction raises        *)
(* Match_failure.                                                            *)
let cinterpolate p q =
  let fm = nnf(And(p,q)) in
  let close v body = Exists(v,body) in
  let efm = itlist close (fv fm) fm
  and corr = map (fun (name,_) -> Fn(name,[]),False) (functions fm) in
  match skolem efm corr with
    And(p',q'),_ -> uinterpolate p' q';;

(* ------------------------------------------------------------------------- *)
(* Now to completely arbitrary formulas.                                     *)
(* ------------------------------------------------------------------------- *)

(* Interpolation for completely arbitrary formulas.                          *)
(* The free variables shared by p and q are temporarily replaced by          *)
(* constants c_n, c_(n+1), ... with n chosen from max_varindex "c_" over the *)
(* existing function names; cinterpolate is run on the results and the       *)
(* constants are then mapped back to the original variables.                 *)
let interpolate p q =
  let shared = map (fun v -> Var v) (intersect (fv p) (fv q)) in
  let fns = functions (And(p,q)) in
  let first =
    itlist (fun sym acc -> max_varindex "c_" (fst sym) acc) fns (Int 0)
    +/ Int 1 in
  let last = first +/ Int(length shared - 1) in
  let consts =
    map (fun i -> Fn("c_"^(string_of_num i),[])) (first---last) in
  let to_consts = instantiate shared consts
  and to_vars = instantiate consts shared in
  let p' = replace to_consts p and q' = replace to_consts q in
  replace to_vars (cinterpolate p' q');;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let p =
 <<(forall x. exists y. R(x,y)) /\
   (forall x y. S(v,x,y) <=> R(x,y) \/ R(y,x))>>
and q =
 <<(forall x y z. S(v,x,y) /\ S(v,y,z) ==> T(x,z)) /\
   (exists u. ~T(u,u))>>;;

let c = interpolate p q;;

meson(Imp(p,c));;
meson(Imp(q,Not c));;
*)

(* ------------------------------------------------------------------------- *)
(* Lift to logic with equality.                                              *)
(* ------------------------------------------------------------------------- *)

(* Interpolation for formulas that may involve equality.  When equalitize    *)
(* changes a formula, the antecedent of the resulting implication is         *)
(* conjoined onto the original formula before calling interpolate.           *)
let einterpolate p q =
  let with_eq_axioms fm =
    let fm' = equalitize fm in
    if fm' = fm then fm else And (fst (dest_imp fm'), fm) in
  let p'' = with_eq_axioms p in
  let q'' = with_eq_axioms q in
  interpolate p'' q'';;

(* ------------------------------------------------------------------------- *)
(* More examples, not in the text.                                           *)
(* ------------------------------------------------------------------------- *)

(*
let p = <<(p ==> q /\ r)>>
and q = <<~((q ==> p) ==> s ==> (p <=> q))>>;;

let c = interpolate p q;;

tautology(Imp(And(p,q),False));;

tautology(Imp(p,c));;
tautology(Imp(q,Not c));;

(* ------------------------------------------------------------------------- *)
(* A more interesting example.                                               *)
(* ------------------------------------------------------------------------- *)

let p = <<(forall x. exists y. R(x,y)) /\
          (forall x y. S(x,y) <=> R(x,y) \/ R(y,x))>>
and q = <<(forall x y z. S(x,y) /\ S(y,z) ==> T(x,z)) /\ ~T(u,u)>>;;

meson(Imp(And(p,q),False));;

let c = interpolate p q;;

meson(Imp(p,c));;
meson(Imp(q,Not c));;

(* ------------------------------------------------------------------------- *)
(* A variant where u is free in both parts.                                  *)
(* ------------------------------------------------------------------------- *)

let p = <<(forall x. exists y. R(x,y)) /\
          (forall x y. S(x,y) <=> R(x,y) \/ R(y,x)) /\
          (forall v. R(u,v) ==> Q(v,u))>>
and q = <<(forall x y z. S(x,y) /\ S(y,z) ==> T(x,z)) /\ ~T(u,u)>>;;

meson(Imp(And(p,q),False));;

let c = interpolate p q;;
meson(Imp(p,c));;
meson(Imp(q,Not c));;

(* ------------------------------------------------------------------------- *)
(* Way of generating examples quite easily (see K&K exercises).              *)
(* ------------------------------------------------------------------------- *)

let test_interp fm =
  let p = generalize(skolemize fm)
  and q = generalize(skolemize(Not fm)) in
  let c = interpolate p q in
  meson(Imp(And(p,q),False)); meson(Imp(p,c)); meson(Imp(q,Not c)); c;;

test_interp <<forall x. P(x) ==> exists y. forall z. P(z) ==> Q(y)>>;;

test_interp <<forall y. exists y. forall z. exists a.
                P(a,x,y,z) ==> P(x,y,z,a)>>;;

(* ------------------------------------------------------------------------- *)
(* Hintikka's examples.                                                      *)
(* ------------------------------------------------------------------------- *)

let p = <<forall x. L(x,b)>>
and q = <<(forall y. L(b,y) ==> m = y) /\ ~(m = b)>>;;

let c = einterpolate p q;;

meson(Imp(p,c));;
meson(Imp(q,Not c));;

let p =
 <<(forall x. A(x) /\ C(x) ==> B(x)) /\ (forall x. D(x) \/ ~D(x) ==> C(x))>>
and q =
 <<~(forall x. E(x) ==> A(x) ==> B(x))>>;;

let c = interpolate p q;;
meson(Imp(p,c));;
meson(Imp(q,Not c));;
*)
(* ========================================================================= *)
(* Nelson-Oppen combined decision procedure.                                 *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Real language with decision procedure.                                    *)
(* ------------------------------------------------------------------------- *)

(* A language is a triple:                                                   *)
(*   - a test on (name,arity) saying which function symbols belong to it;    *)
(*   - a test on (name,arity) saying which predicate symbols belong to it;   *)
(*   - a decision procedure for formulas over those symbols.                 *)
(* Here the functions are numerals (nullary symbols accepted by is_numeral)  *)
(* plus the listed arithmetic operators, and a formula is accepted when      *)
(* real_qelim reduces its generalization to True.                            *)
let real_lang =
  let fn = ["-",1; "+",2; "-",2; "*",2; "^",2]
  and pr = ["<=",2; "<",2; ">=",2; ">",2] in
  (fun (s,n) -> (n = 0 && is_numeral(Fn(s,[]))) || mem (s,n) fn),
  (fun sn -> mem sn pr),
  (fun fm -> real_qelim(generalize fm) = True);;

(* ------------------------------------------------------------------------- *)
(* Integer language with decision procedure.                                 *)
(* ------------------------------------------------------------------------- *)

(* Language triple (function test, predicate test, decision procedure) for   *)
(* integer arithmetic: numerals (nullary symbols accepted by is_numeral)     *)
(* and the listed operators; a formula is accepted when integer_qelim        *)
(* reduces its generalization to True.                                       *)
let int_lang =
  let fn = ["-",1; "+",2; "-",2; "*",2]
  and pr = ["<=",2; "<",2; ">=",2; ">",2] in
  (fun (s,n) -> (n = 0 && is_numeral(Fn(s,[]))) || mem (s,n) fn),
  (fun sn -> mem sn pr),
  (fun fm -> integer_qelim(generalize fm) = True);;

(* ------------------------------------------------------------------------- *)
(* Add any uninterpreted symbols to a list of languages.                     *)
(* ------------------------------------------------------------------------- *)

(* Append a catch-all language to langs: it claims exactly those function    *)
(* and predicate symbols that none of the given languages claims, and uses   *)
(* ccvalid as its decision procedure.                                        *)
let add_default langs =
  let unclaimed_fn sym =
    not (exists (fun (is_fn, _, _) -> is_fn sym) langs) in
  let unclaimed_pr sym =
    not (exists (fun (_, is_pr, _) -> is_pr sym) langs) in
  langs @ [(unclaimed_fn, unclaimed_pr, ccvalid)];;

(* ------------------------------------------------------------------------- *)
(* Choose a language for homogenization of an atom.                          *)
(* ------------------------------------------------------------------------- *)

(* Pick the first language in langs that owns the atom fm.  An equation      *)
(* with a function application on either side is classified by that          *)
(* function symbol (the left side is tried first); every other atom is       *)
(* classified by its predicate symbol.  Only atoms are handled.              *)
let chooselang langs fm =
  let by_function sym = find (fun (is_fn, _, _) -> is_fn sym) langs in
  let by_predicate sym = find (fun (_, is_pr, _) -> is_pr sym) langs in
  match fm with
  | Atom (R ("=", [Fn (f, args); _]))
  | Atom (R ("=", [_; Fn (f, args)])) -> by_function (f, length args)
  | Atom (R (p, args)) -> by_predicate (p, length args);;

(* ------------------------------------------------------------------------- *)
(* Make a variable.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Build the variable term named v_N from the number n.                      *)
let mkvar n =
  let name = "v_" ^ string_of_num n in
  Var name;;

(* ------------------------------------------------------------------------- *)
(* General listification for CPS-style function.                             *)
(* ------------------------------------------------------------------------- *)

(* Lift a continuation-passing function f (taking an element and a           *)
(* continuation for its result) to lists: the elements are processed left    *)
(* to right and cont finally receives the list of results.                   *)
let rec listify f l cont =
  match l with
  | [] -> cont []
  | x :: rest ->
      let after_head x' = listify f rest (fun rest' -> cont (x' :: rest')) in
      f x after_head;;

(* ------------------------------------------------------------------------- *)
(* Homogenize a term.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Homogenize the term tm with respect to the language (fn,pr,dp), in        *)
(* continuation-passing style.  n is the next fresh variable index and defs  *)
(* the definitions accumulated so far.  A subterm whose top function symbol  *)
(* is not in the language is replaced by a fresh variable v_n, and the       *)
(* equation v_n = subterm is pushed onto defs.                               *)
let rec homot (fn,pr,dp) tm cont n defs =
  match tm with
  | Var _ -> cont tm n defs
  | Fn (f, args) when fn (f, length args) ->
      let rebuild args' = cont (Fn (f, args')) in
      listify (homot (fn, pr, dp)) args rebuild n defs
  | Fn _ ->
      let fresh = mkvar n in
      let def = Atom (R ("=", [fresh; tm])) in
      cont fresh (n +/ Int 1) (def :: defs);;

(* ------------------------------------------------------------------------- *)
(* Homogenize a literal.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Homogenize a literal: strip any negations, choose the language of the     *)
(* underlying atom with chooselang, homogenize each argument against that    *)
(* language, then rebuild the literal.  Fails on anything that is not a      *)
(* (possibly negated) atom.                                                  *)
let rec homol langs fm cont n defs =
  match fm with
  | Not inner ->
      let negated lit = cont (Not lit) in
      homol langs inner negated n defs
  | Atom (R (p, args)) ->
      let lang = chooselang langs fm in
      let rebuild args' = cont (Atom (R (p, args'))) in
      listify (homot lang) args rebuild n defs
  | _ -> failwith "homol: not a literal";;

(* ------------------------------------------------------------------------- *)
(* Fully homogenize a list of literals.                                      *)
(* ------------------------------------------------------------------------- *)

(* Homogenize a list of literals.  Definitions generated along the way are   *)
(* themselves literals that may need homogenizing, so the process repeats    *)
(* on them (starting again from an empty definition list) until a pass       *)
(* produces no new definitions.                                              *)
let rec homo langs fms cont =
  let finish done_lits n defs =
    match defs with
    | [] -> cont done_lits n defs
    | _ :: _ ->
        homo langs defs (fun more -> cont (done_lits @ more)) n [] in
  listify (homol langs) fms finish;;

(* ------------------------------------------------------------------------- *)
(* Overall homogenization.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Entry point for homogenization: start numbering fresh v_ variables just   *)
(* above the largest v_ index among the free variables of fms, and return    *)
(* only the resulting list of literals.                                      *)
let homogenize langs fms =
  let used_vars = unions (map fv fms) in
  let highest = itlist (max_varindex "v_") used_vars (Int 0) in
  let first_fresh = Int 1 +/ highest in
  let keep_literals lits _ _ = lits in
  homo langs fms keep_literals first_fresh [];;

(* ------------------------------------------------------------------------- *)
(* Whether a formula belongs to a language.                                  *)
(* ------------------------------------------------------------------------- *)

(* A formula belongs to the language (fn,pr,dp) when every function symbol   *)
(* in it satisfies fn and every predicate symbol other than binary equality  *)
(* satisfies pr.  Equality is shared by all languages, so it is ignored.     *)
let belongs (fn,pr,dp) fm =
  forall fn (functions fm) &&
  forall pr (subtract (predicates fm) ["=",2]);;

(* ------------------------------------------------------------------------- *)
(* Partition formulas among a list of languages.                             *)
(* ------------------------------------------------------------------------- *)

(* Split fms into one sublist per language, in the order of langs: each      *)
(* language takes every remaining formula that belongs to it.  Fails if      *)
(* formulas are left over once the languages run out.                        *)
let rec langpartition langs fms =
  match langs, fms with
  | [], [] -> []
  | [], _ :: _ -> failwith "langpartition"
  | lang :: others, _ ->
      let mine, leftover = partition (belongs lang) fms in
      let rest = langpartition others leftover in
      mine :: rest;;

(* ------------------------------------------------------------------------- *)
(* Turn an arrangement (partition) of variables into corresponding formula.  *)
(* ------------------------------------------------------------------------- *)

(* Equations chaining together the variables of one equivalence class:       *)
(* for v1, v2, v3, ... produce v1 = v2, v2 = v3, ...  Classes with fewer     *)
(* than two members give no equations.                                       *)
let rec arreq l =
  match l with
  | [] | [_] -> []
  | v1 :: (v2 :: _ as tail) ->
      let link = mk_eq (Var v1) (Var v2) in
      link :: arreq tail;;

(* Formulas describing the partition part of a set of variables: the         *)
(* chained equations within each class (see arreq), plus a disequality       *)
(* between the first members of every pair of distinct classes.              *)
let arrangement part =
  let representatives = map hd part in
  let disequalities =
    map (fun (v, w) -> Not (mk_eq (Var v) (Var w)))
        (distinctpairs representatives) in
  itlist (fun cls acc -> union (arreq cls) acc) part disequalities;;

(* ------------------------------------------------------------------------- *)
(* Attempt to substitute with trivial equations.                             *)
(* ------------------------------------------------------------------------- *)

(* Recognize an equation between a variable x and a term t not containing    *)
(* x, in either orientation, and return the pair (x,t).  Fails with          *)
(* Failure on every other formula.                                           *)
let dest_def fm =
  let usable x t = not (mem x (fvt t)) in
  match fm with
  | Atom (R ("=", [Var x; t])) when usable x t -> (x, t)
  | Atom (R ("=", [t; Var x])) when usable x t -> (x, t)
  | _ -> failwith "dest_def";;

(* Repeatedly eliminate definitional equations: take the first formula       *)
(* accepted by dest_def, drop it, substitute its right-hand side for its     *)
(* variable throughout the remaining formulas, and continue.  When no such   *)
(* equation is left (or any step raises Failure) the current list is         *)
(* returned.                                                                 *)
let rec redeqs eqs =
  try
    let def = find (can dest_def) eqs in
    let x, t = dest_def def in
    let others = subtract eqs [def] in
    let inst = formsubst (x := t) in
    redeqs (map inst others)
  with Failure _ -> eqs;;

(* ------------------------------------------------------------------------- *)
(* Naive Nelson-Oppen variant trying all arrangements.                       *)
(* ------------------------------------------------------------------------- *)

(* ldseps pairs each language with its own formulas.  Succeeds when, for     *)
(* some language, its decision procedure accepts the negation of the         *)
(* conjunction of its formulas together with fms (after redeqs has           *)
(* simplified away definitional equations).                                  *)
let trydps ldseps fms =
  let refuted_by ((_, _, dp), fms0) =
    let reduced = redeqs (fms0 @ fms) in
    dp (Not (list_conj reduced)) in
  exists refuted_by ldseps;;

(* Naive refutation: every partition of the shared variables, turned into    *)
(* an arrangement, must be refuted by some component procedure.              *)
let nelop_refute vars ldseps =
  let refutable part = trydps ldseps (arrangement part) in
  forall refutable (allpartitions vars);;

(* Refute one conjunction of literals: homogenize, split the literals among  *)
(* the languages, collect the variables free in at least two of the          *)
(* components, and run nelop_refute over them.                               *)
let nelop1 langs fms0 =
  let seps = langpartition langs (homogenize langs fms0) in
  let fvlist = map (fun sep -> unions (map fv sep)) seps in
  let is_shared x = length (filter (mem x) fvlist) >= 2 in
  let vars = filter is_shared (unions fvlist) in
  nelop_refute vars (zip langs seps);;

(* Decide fm by refuting every disjunct of the DNF of its negation.          *)
let nelop langs fm =
  let disjuncts = simpdnf (simplify (Not fm)) in
  forall (fun lits -> nelop1 langs lits) disjuncts;;

(* ------------------------------------------------------------------------- *)
(* Check that our example works.                                             *)
(* ------------------------------------------------------------------------- *)

(*
nelop (add_default [int_lang])
 <<f(v - 1) - 1 = v + 1 /\ f(u) + 1 = u - 1 /\ u + 1 = v ==> false>>;;

(* ------------------------------------------------------------------------- *)
(* Take note of our case explosion.                                          *)
(* ------------------------------------------------------------------------- *)

let bell n = length(allpartitions (1--n));;
map bell (1--10);;
*)

(* ------------------------------------------------------------------------- *)
(* Find the smallest subset satisfying a predicate.                          *)
(* ------------------------------------------------------------------------- *)

(* Search the m-element sublists of l (elements kept in their original       *)
(* order) for one on which p does not raise Failure, and return the result   *)
(* of p on it.  Sublists containing an earlier element are tried before      *)
(* those that skip it.  Raises Failure if every candidate fails.             *)
let rec findasubset p m l =
  match m, l with
  | 0, _ -> p []
  | _, [] -> failwith "findasubset"
  | _, h :: t ->
      let with_h s = p (h :: s) in
      (try findasubset with_h (m - 1) t
       with Failure _ -> findasubset p m t);;

(* Return a sublist of l satisfying the predicate p, trying sizes 0, 1, ...  *)
(* up to the length of l in turn, so the first hit has minimal size.         *)
(* Fails if no sublist satisfies p.                                          *)
let findsubset p l =
  let accept s = if p s then s else failwith "" in
  let of_size n = findasubset accept n l in
  tryfind of_size (0 -- length l);;

(* ------------------------------------------------------------------------- *)
(* The "true" Nelson-Oppen method.                                           *)
(* ------------------------------------------------------------------------- *)

(* Refutation by case splitting on entailed disjunctions of equations.       *)
(* Find a smallest set dj of candidate equations whose negations, taken      *)
(* together, are refuted by some component; then each equation in dj is      *)
(* added to every component in turn and the refutation is continued with     *)
(* that equation removed from the candidates.  If no such set exists (or     *)
(* any step raises Failure) the refutation attempt fails.                    *)
let rec nelop_refute eqs ldseps =
  try
    let entailed lits = trydps ldseps (map negate lits) in
    let dj = findsubset entailed eqs in
    let refute_case eq =
      let remaining = subtract eqs [eq] in
      let extended = map (fun (dps, es) -> (dps, eq :: es)) ldseps in
      nelop_refute remaining extended in
    forall refute_case dj
  with Failure _ -> false;;

(* Refute one conjunction of literals: homogenize, split the literals among  *)
(* the languages, form all equations between distinct pairs of variables     *)
(* free in at least two components, and run nelop_refute on them.            *)
let nelop1 langs fms0 =
  let seps = langpartition langs (homogenize langs fms0) in
  let fvlist = map (fun sep -> unions (map fv sep)) seps in
  let is_shared x = length (filter (mem x) fvlist) >= 2 in
  let vars = filter is_shared (unions fvlist) in
  let eqs =
    map (fun (a, b) -> mk_eq (Var a) (Var b)) (distinctpairs vars) in
  nelop_refute eqs (zip langs seps);;

(* Decide fm by refuting every disjunct of the DNF of its negation.          *)
let nelop langs fm =
  let disjuncts = simpdnf (simplify (Not fm)) in
  forall (fun lits -> nelop1 langs lits) disjuncts;;

(* ------------------------------------------------------------------------- *)
(* Some additional examples (from ICS paper and Shostak's "A practical...")  *)
(* ------------------------------------------------------------------------- *)

(*
nelop (add_default [int_lang])
 <<y <= x /\ y >= x + z /\ z >= 0 ==> f(f(x) - f(y)) = f(z)>>;;

nelop (add_default [int_lang])
 <<x = y /\ y >= z /\ z >= x ==> f(z) = f(x)>>;;

nelop (add_default [int_lang])
 <<a <= b /\ b <= f(a) /\ f(a) <= 1
  ==> a + b <= 1 \/ b + f(b) <= 1 \/ f(f(b)) <= f(a)>>;;

(* ------------------------------------------------------------------------- *)
(* Confirmation of non-convexity.                                            *)
(* ------------------------------------------------------------------------- *)

map (real_qelim ** generalize)
  [<<x * y = 0 /\ z = 0 ==> x = z \/ y = z>>;
   <<x * y = 0 /\ z = 0 ==> x = z>>;
   <<x * y = 0 /\ z = 0 ==> y = z>>];;

map (integer_qelim ** generalize)
  [<<0 <= x /\ x < 2 /\ y = 0 /\ z = 1 ==> x = y \/ x = z>>;
   <<0 <= x /\ x < 2 /\ y = 0 /\ z = 1 ==> x = y>>;
   <<0 <= x /\ x < 2 /\ y = 0 /\ z = 1 ==> x = z>>];;

(* ------------------------------------------------------------------------- *)
(* Failures of original Shostak procedure.                                   *)
(* ------------------------------------------------------------------------- *)

nelop (add_default [int_lang])
 <<f(v - 1) - 1 = v + 1 /\ f(u) + 1 = u - 1 /\ u + 1 = v ==> false>>;;

nelop (add_default [int_lang])
 <<f(v) = v /\ f(u) = u - 1 /\ u = v ==> false>>;;

(* ------------------------------------------------------------------------- *)
(* Additional examples.                                                      *)
(* ------------------------------------------------------------------------- *)

time (nelop (add_default [int_lang]))
 <<z = f(x - y) /\ x = z + y /\ ~(-(y) = -(x - f(f(z)))) ==> false>>;;

time (nelop (add_default [int_lang]))
 <<(x = y /\ z = 1 ==> f(f((x+z))) = f(f((1+y))))>>;;

time (nelop (add_default [int_lang]))
 <<hd(x) = hd(y) /\ tl(x) = tl(y) /\ ~(x = nil) /\ ~(y = nil)
   ==> f(x) = f(y)>>;;

time (nelop (add_default [int_lang]))
 <<~(f(f(x) - f(y)) = f(z)) /\ y <= x /\ y >= x + z /\ z >= 0 ==> false>>;;

time (nelop (add_default [int_lang]))
 <<x < f(y) + 1 /\ f(y) <= x ==> (P(x,y) <=> P(f(y),y))>>;;

time (nelop (add_default [int_lang]))
 <<(x >= y ==> MAX(x,y) = x) /\ (y >= x ==> MAX(x,y) = y)
   ==> x = y + 2 ==> MAX(x,y) = x>>;;

time (nelop (add_default [int_lang]))
 <<x <= g(x) /\ x >= g(x) ==> x = g(g(g(g(x))))>>;;

time (nelop (add_default [real_lang]))
 <<x^2 =  1 ==> (f(x) = f(-(x)))  ==> (f(x) = f(1))>>;;

time (nelop (add_default [int_lang]))
 <<2 * f(x + y) = 3 * y /\ 2 * x = y ==> f(f(x + y)) = 3 * x>>;;

*)
(* ========================================================================= *)
(* Finite state transition systems.                                          *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Rebind default_parser to parsep for the code that follows.                *)
(* NOTE(review): presumably this selects the parser used for quoted          *)
(* formulas in this section -- confirm against the quotation setup.          *)
let default_parser = parsep;;

(* ------------------------------------------------------------------------- *)
(* Transition relation for modulo-5 counter.                                 *)
(* ------------------------------------------------------------------------- *)

(*
let counter_trans =
 <<(v0' <=> ~v0 /\ ~v2) /\
   (v1' <=> ~(v0 <=> v1)) /\
   (v2' <=> v0 /\ v1)>>;;

(* ------------------------------------------------------------------------- *)
(* Transition relation for incorrect mutex algorithm.                        *)
(* ------------------------------------------------------------------------- *)

let mutex_trans =
 <<(q2' <=> q2) /\ (q1' <=> q1) /\ (q0' <=> q0) /\
  (~p2 /\ ~p1 /\ ~p0 /\ ~v1 /\ ~v0
   ==> ~p2' /\ ~p1' /\ p0' /\ ~v1' /\ ~v0') /\
  (~p2 /\ ~p1 /\ ~p0 /\ (v1 \/ v0)
   ==> ~p2' /\ ~p1' /\ ~p0' /\ (v1' <=> v1) /\ (v0' <=> v0)) /\
  (~p2 /\ ~p1 /\ p0
   ==> ~p2' /\ p1' /\ ~p0' /\ ~v1' /\ v0') /\
  (~p2 /\ p1 /\ ~p0
   ==> ~p2' /\ p1' /\ p0' /\ (v1' <=> v1) /\ (v0' <=> v0)) /\
  (~p2 /\ p1 /\ p0
   ==> p2' /\ ~p1' /\ ~p0' /\ ~v1' /\ ~v0') /\
  (p2 /\ ~p1 /\ ~p0
   ==> ~p2' /\ ~p1' /\ ~p0' /\ (v1' <=> v1) /\ (v0' <=> v0)) \/
  (p2' <=> p2) /\ (p1' <=> p1) /\ (p0' <=> p0) /\
  (~q2 /\ ~q1 /\ ~q0 /\ ~v1 /\ ~v0
   ==> ~q2' /\ ~q1' /\ q0' /\ ~v1' /\ ~v0') /\
  (~q2 /\ ~q1 /\ ~q0 /\ (v1 \/ v0)
   ==> ~q2' /\ ~q1' /\ ~q0' /\ (v1' <=> v1) /\ (v0' <=> v0)) /\
  (~q2 /\ ~q1 /\ q0
   ==> ~q2' /\ q1' /\ ~q0' /\ v1' /\ ~v0') /\
  (~q2 /\ q1 /\ ~q0
   ==> ~q2' /\ q1' /\ q0' /\ (v1' <=> v1) /\ (v0' <=> v0)) /\
  (~q2 /\ q1 /\ q0
   ==> q2' /\ ~q1' /\ ~q0' /\ ~v1' /\ ~v0') /\
  (q2 /\ ~q1 /\ ~q0
   ==> ~q2' /\ ~q1' /\ ~q0' /\ (v1' <=> v1) /\ (v0' <=> v0))>>;;

(* ------------------------------------------------------------------------- *)
(* Same for Peterson's algorithm.                                            *)
(* ------------------------------------------------------------------------- *)

let peter_trans =
 <<(q2' <=> q2) /\ (q1' <=> q1) /\ (q0' <=> q0) /\
  (~p2 /\ ~p1 /\ ~p0
   ==> ~p2' /\ ~p1' /\ p0' /\ f1' /\ (f2' <=> f2) /\ (t' <=> t)) /\
  (~p2 /\ ~p1 /\ p0
   ==> ~p2' /\ p1' /\ ~p0' /\ (f1' <=> f1) /\ (f2' <=> f2) /\ t') /\
  (~p2 /\ p1 /\ ~p0 /\ f2
   ==> ~p2' /\ p1' /\ p0' /\ f2' /\ (f1' <=> f1) /\ (t' <=> t)) /\
  (~p2 /\ p1 /\ ~p0 /\ ~f2
   ==> p2' /\ ~p1' /\ ~p0' /\ ~f2' /\ (f1' <=> f1) /\ (t' <=> t)) /\
  (~p2 /\ p1 /\ p0 /\ t
   ==> ~p2' /\ p1' /\ ~p0' /\ t' /\ (f1' <=> f1) /\ (f2' <=> f2)) /\
  (~p2 /\ p1 /\ p0 /\ ~t
   ==> p2' /\ ~p1' /\ ~p0' /\ ~t' /\ (f1' <=> f1) /\ (f2' <=> f2)) /\
  (p2 /\ ~p1 /\ ~p0
   ==> p2' /\ ~p1' /\ p0' /\
       (f1' <=> f1) /\ (f2' <=> f2) /\ (t' <=> t)) /\
  (p2 /\ ~p1 /\ p0
   ==> p2' /\ p1' /\ ~p0' /\ ~f1' /\ (f2' <=> f2) /\ (t' <=> t)) /\
  (p2 /\ p1 /\ ~p0
   ==> ~p2' /\ ~p1' /\ ~p0' /\
       (f1' <=> f1) /\ (f2' <=> f2) /\ (t' <=> t)) \/
  (p2' <=> p2) /\ (p1' <=> p1) /\ (p0' <=> p0) /\
  (~q2 /\ ~q1 /\ ~q0
   ==> ~q2' /\ ~q1' /\ q0' /\ f2' /\ (f1' <=> f1) /\ (t' <=> t)) /\
  (~q2 /\ ~q1 /\ q0
   ==> ~q2' /\ q1' /\ ~q0' /\ (f1' <=> f1) /\ (f2' <=> f2) /\ ~t') /\
  (~q2 /\ q1 /\ ~q0 /\ f1
   ==> ~q2' /\ q1' /\ q0' /\ f1' /\ (f2' <=> f2) /\ (t' <=> t)) /\
  (~q2 /\ q1 /\ ~q0 /\ ~f1
   ==> q2' /\ ~q1' /\ ~q0' /\ ~f1' /\ (f2' <=> f2) /\ (t' <=> t)) /\
  (~q2 /\ q1 /\ q0 /\ ~t
   ==> ~q2' /\ q1' /\ ~q0' /\ ~t' /\ (f1' <=> f1) /\ (f2' <=> f2)) /\
  (~q2 /\ q1 /\ q0 /\ t
   ==> q2' /\ ~q1' /\ ~q0' /\ t' /\ (f1' <=> f1) /\ (f2' <=> f2)) /\
  (q2 /\ ~q1 /\ ~q0
   ==> q2' /\ ~q1' /\ q0' /\
       (f1' <=> f1) /\ (f2' <=> f2) /\ (t' <=> t)) /\
  (q2 /\ ~q1 /\ q0
   ==> q2' /\ q1' /\ ~q0' /\ ~f2' /\ (f1' <=> f1) /\ (t' <=> t)) /\
  (q2 /\ q1 /\ ~q0
   ==> ~q2' /\ ~q1' /\ ~q0' /\
       (f1' <=> f1) /\ (f2' <=> f2) /\ (t' <=> t))>>;;

(* ------------------------------------------------------------------------- *)
(* Example of "induction" method for reachability.                           *)
(* ------------------------------------------------------------------------- *)

tautology(Imp(counter_trans,<<~(v0 /\ v2) ==> ~(v0' /\ v2')>>));;
*)

(* ------------------------------------------------------------------------- *)
(* Useful combinators for applying functions maintaining bdd state.          *)
(* ------------------------------------------------------------------------- *)

(* Run the state-threading function f on x starting from state bst, then     *)
(* pass the new state and the result on to fn.                               *)
let single f bst x fn =
  match f bst x with
  | state, result -> fn state result;;

(* Run the state-threading function f on x and then on y, threading the      *)
(* state through both calls, and pass the final state together with the      *)
(* pair of results on to fn.                                                 *)
let double f bst x y fn =
  match f bst x with
  | bst1, x' ->
      (match f bst1 y with
       | bst2, y' -> fn bst2 (x', y'));;

(* ------------------------------------------------------------------------- *)
(* More uniform BDD operations all with BDD and two computed tables.         *)
(* ------------------------------------------------------------------------- *)

(* Conjunction on the three-component state (bdd, acomp, pcomp): delegates   *)
(* to bdd_and on the first two components and carries pcomp through          *)
(* unchanged.                                                                *)
let bdd_And (bdd,acomp,pcomp) (m,n) =
  match bdd_and (bdd, acomp) (m, n) with
  | (bdd', acomp'), node -> ((bdd', acomp', pcomp), node);;

(* Build the node for variable s with branches l and r via mk_node; only     *)
(* the bdd component of the state changes.                                   *)
let bdd_Node s (bdd,acomp,pcomp) (l,r) =
  let bdd', node = mk_node bdd (s, l, r) in
  ((bdd', acomp, pcomp), node);;

(* Build the BDD node for formula fm using bddify (with an empty first       *)
(* argument) on the bdd and acomp components; pcomp is carried through       *)
(* unchanged.                                                                *)
let bdd_Make (bdd,acomp,pcomp) fm =
  match bddify [] (bdd, acomp) fm with
  | (bdd', acomp'), node -> ((bdd', acomp', pcomp), node);;

(* ------------------------------------------------------------------------- *)
(* Iterative version of bdd_Make for a list of formulas.                     *)
(* ------------------------------------------------------------------------- *)

(* Apply bdd_Make to each formula of fms from left to right, threading the   *)
(* state, and return the final state with the nodes in the same order as     *)
(* the formulas.                                                             *)
let bdd_Makes bst fms =
  let add (state, rev_nodes) fm =
    let state', node = bdd_Make state fm in
    (state', node :: rev_nodes) in
  let final, rev_nodes = List.fold_left add (bst, []) fms in
  (final, List.rev rev_nodes);;

(* ------------------------------------------------------------------------- *)
(* Derived BDD logical operations.                                           *)
(* ------------------------------------------------------------------------- *)

(* Disjunction by De Morgan: negate both operands, conjoin, negate the       *)
(* result.  Negation of a node is arithmetic negation of its index.          *)
let bdd_Or bst (m,n) =
  match bdd_And bst (-m, -n) with
  | bst', conj -> (bst', -conj);;

(* Implication m ==> n, computed as the disjunction of not-m and n.          *)
let bdd_Imp bst (m,n) =
  let not_m = -m in
  bdd_Or bst (not_m, n);;

(* Negation: the state is returned untouched and the node index is negated.  *)
let bdd_Not bst n =
  let negated = -n in
  (bst, negated);;

(* Equivalence as the conjunction of the two implications, computed in the   *)
(* order m ==> n, then n ==> m, threading the state through both.            *)
let bdd_Iff bst (m,n) =
  let bst1, forward = bdd_Imp bst (m, n) in
  let bst2, backward = bdd_Imp bst1 (n, m) in
  bdd_And bst2 (forward, backward);;

(* ------------------------------------------------------------------------- *)
(* Combined "Pre" operation, doing relational product and priming second     *)
(* BDD's variables at the same time.                                         *)
(*                                                                           *)
(* Given arguments vs', r[vs,vs'] and p[vs], this produces the BDD for       *)
(*                                                                           *)
(*         exists vs'. r[vs,vs'] /\ p[vs']                                   *)
(*                                                                           *)
(* We must have the same relative orders of primed and unprimed variables!   *)
(* ------------------------------------------------------------------------- *)

(* bdd_Pre evs bst (m1,m2): m1 is the node for the relation, m2 the node     *)
(* for the predicate over unprimed variables, and evs the list of primed     *)
(* variables to quantify away.  m2 is traversed as if each of its variables  *)
(* carried a prime.  Node 1 stands for true and -1 for false.  Results are   *)
(* memoized in pcomp, keyed by the pair (m1,m2).                             *)
let rec bdd_Pre evs (bdd,acomp,pcomp as bst) (m1,m2) =
  (* A false operand makes the whole conjunction false. *)
  if m1 = -1 || m2 = -1 then bst,-1
  (* Relation constantly true (and m2 not false): the result is true. *)
  else if m1 = 1 then bst,1 else
  (* Reuse a memoized answer if there is one; otherwise compute it. *)
  try bst,apply pcomp (m1,m2) with Failure _ ->
  let (s1,l1,r1) = expand_node bdd m1
  and (s0,l2,r2) = expand_node bdd m2 in
  (* Top variable of m2 with a prime appended to its name. *)
  let s0' = P(pname s0^"'") in
  (* Branch on whichever top variable comes first: both operands when they   *)
  (* agree, otherwise only the operand that owns that variable.  The test    *)
  (* s0 = P "" covers m2 having no variable to branch on.                    *)
  let (s,lpair,rpair) =
      if s1 = s0' then s1,(l1,l2),(r1,r2)
      else if s0 = P "" || order bdd s1 s0' then s1,(l1,m2),(r1,m2)
      else s0',(m1,l2),(m1,r2) in
  (* A quantified variable is eliminated by or-ing the two branches; any     *)
  (* other variable is kept as a node.                                       *)
  let bdd_orex = if mem s evs then bdd_Or else bdd_Node s in
  let (bdd',acomp',pcomp'),n =
    double (bdd_Pre evs) bst lpair rpair bdd_orex in
  (* Record the result for this pair before returning it. *)
  (bdd',acomp',((m1,m2) |-> n) pcomp'),n;;

(* ------------------------------------------------------------------------- *)
(* Iterate a BDD operation till a fixpoint is reached.                       *)
(* ------------------------------------------------------------------------- *)

(* Apply the state-threading step f repeatedly, starting from n, until a     *)
(* step returns the value it was given; return the state and value from      *)
(* that last step.                                                           *)
let rec iterate_to_fixpoint f bst n =
  match f bst n with
  | (_, next) as settled when next = n -> settled
  | bst', next -> iterate_to_fixpoint f bst' next;;

(* ------------------------------------------------------------------------- *)
(* Model-check EF(p) by iterating a |-> p \/ Pre(a).                         *)
(* ------------------------------------------------------------------------- *)

(* One step of the EF iteration: map a to p \/ Pre(a).                       *)
let step_EF evs r p bst a =
  match bdd_Pre evs bst (r, a) with
  | bst', pre_a -> bdd_Or bst' (p, pre_a);;

(* Iterate step_EF to a fixpoint, starting from the false node (-1).         *)
let check_EF evs r bst p =
  let step = step_EF evs r p in
  let bottom = -1 in
  iterate_to_fixpoint step bst bottom;;

(* ------------------------------------------------------------------------- *)
(* Simple reachability. (Can we get from s to p via relation r?)             *)
(* ------------------------------------------------------------------------- *)

(* vars lists the state variable names, s is the start condition, r the      *)
(* transition relation and p the target condition.  Builds BDDs for all      *)
(* three in a fresh state, computes EF p quantifying over the primed         *)
(* variables, and reports whether its conjunction with s is not the false    *)
(* node.                                                                     *)
let reachable vars s r p =
  let primed = map (fun name -> P (name ^ "'")) vars in
  let empty_state = (mk_bdd (fun a b -> a < b), undefined, undefined) in
  let bst1, [n_s; n_r; n_p] = bdd_Makes empty_state [s; r; p] in
  let bst2, can_reach = check_EF primed n_r bst1 n_p in
  let _, overlap = bdd_And bst2 (n_s, can_reach) in
  overlap <> -1;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
reachable ["v2"; "v1"; "v0"]
  <<true>> counter_trans <<v2 /\ v1>>;;

reachable ["v2"; "v1"; "v0"]
  <<~v2 /\ ~v1 /\ ~v0>> counter_trans <<v2 /\ v1>>;;

reachable ["p2"; "p1"; "p0"; "q2"; "q1"; "q0"; "v1"; "v0"]
  <<~p2 /\ ~p1 /\ ~p0 /\ ~q2 /\ ~q1 /\ ~q0 /\ ~v1 /\ ~v0>>
  mutex_trans
  <<~p2 /\ p1 /\ ~p0 /\ ~q2 /\ q1 /\ ~q0>>;;

reachable ["p2"; "p1"; "p0"; "q2"; "q1"; "q0"; "f2"; "f1"; "t"]
  <<~p2 /\ ~p1 /\ ~p0 /\ ~q2 /\ ~q1 /\ ~q0>>
  peter_trans
  <<p2 /\ ~p1 /\ ~p0 /\ q2 /\ ~q1 /\ ~q0>>;;
*)
(* ========================================================================= *)
(* Temporal logic.                                                           *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Temporal-logic formulas over propositional variables: the usual           *)
(* connectives plus the temporal operators Next, Box and Diamond.            *)
type tform = Falset
           | Truet
           | Propvart of string
           | Nott of tform
           | Andt of tform * tform
           | Ort of tform * tform
           | Impt of tform * tform
           | Ifft of tform * tform
           | Next of tform
           | Box of tform
           | Diamond of tform;;

(* ------------------------------------------------------------------------- *)
(* Basic semantics for arbitrary valuation-sequence.                         *)
(* ------------------------------------------------------------------------- *)

(* teval fm v evaluates fm at instant 0 of the valuation sequence v, where   *)
(* v i x is the truth value of variable x at instant i.  Next shifts the     *)
(* sequence by one instant.  Box p holds when p holds now and Box p holds    *)
(* from the next instant on; Diamond p holds when p holds now or Diamond p   *)
(* holds from the next instant on.  Both therefore recurse on the whole      *)
(* formula fm, not on p: recursing on p would only ever inspect instants 0   *)
(* and 1.                                                                    *)
(*                                                                           *)
(* This is a definition of the semantics rather than a decision procedure:   *)
(* evaluation of Box p does not terminate when p holds at every instant,     *)
(* and evaluation of Diamond p does not terminate when p never holds.        *)
let rec teval fm v =
  match fm with
    Falset -> false
  | Truet -> true
  | Propvart(x) -> v 0 x
  | Nott(p) -> not(teval p v)
  | Andt(p,q) -> (teval p v) && (teval q v)
  | Ort(p,q) -> (teval p v) || (teval q v)
  | Impt(p,q) -> not(teval p v) || (teval q v)
  | Ifft(p,q) -> (teval p v) = (teval q v)
  | Next p -> teval p (fun i -> v(i + 1))
  | Box p -> teval p v && teval fm (fun i -> v(i + 1))
  | Diamond p -> teval p v || teval fm (fun i -> v(i + 1));;

(* ------------------------------------------------------------------------- *)
(* Proof via first order reduction.                                          *)
(* ------------------------------------------------------------------------- *)

(* Rebind default_parser to parse for the code that follows.                 *)
(* NOTE(review): presumably this selects the parser used for the quoted      *)
(* first-order formulas below -- confirm against the quotation setup.        *)
let default_parser = parse;;

(*
meson
 <<~(forall t'. t <= t' ==> p(t)) <=> exists t'. t <= t' /\ ~p(t)>>;;

meson
 <<(forall t. t <= t)
  ==> (forall t'. t <= t' ==> forall t''. t' <= t'' ==> p(t''))
      ==> forall t'. t <= t' ==> p(t')>>;;

meson
 <<(forall s t u. s <= t /\ t <= u ==> s <= u)
  ==> (forall t'. t <= t' ==> p(t'))
      ==>  (forall t'. t <= t' ==> forall t''. t' <= t'' ==> p(t''))>>;;
*)
(* ========================================================================= *)
(* CTL model checking.                                                       *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Rebind default_parser back to parsep for the code that follows.           *)
(* NOTE(review): presumably this selects the parser used for quoted          *)
(* formulas in this section -- confirm against the quotation setup.          *)
let default_parser = parsep;;

(* ------------------------------------------------------------------------- *)
(* Define CTL syntax (with the path and temporal operators combined).        *)
(* ------------------------------------------------------------------------- *)

(* CTL state formulas.  The temporal constructors each pair a path           *)
(* quantifier (A or E) with a temporal operator (F, G, X or U).  modelcheck  *)
(* handles EX, EG and EU directly and rewrites the remaining ones in terms   *)
(* of those three.                                                           *)
type sform =
  (* Propositional part. *)
  | Falsec
  | Truec
  | Propvarc of string
  | Notc of sform
  | Andc of sform * sform
  | Orc of sform * sform
  | Impc of sform * sform
  | Iffc of sform * sform
  (* Operators with the A path quantifier. *)
  | AF of sform
  | AG of sform
  | AX of sform
  | AU of sform * sform
  (* Operators with the E path quantifier. *)
  | EF of sform
  | EG of sform
  | EX of sform
  | EU of sform * sform;;

(* ------------------------------------------------------------------------- *)
(* Model-check EX(p) by simply mapping a |-> Pre(a).                         *)
(* ------------------------------------------------------------------------- *)

(* EX p is a single Pre step of p through the relation r.                    *)
let check_EX evs r bst p =
  let relation_and_target = (r, p) in
  bdd_Pre evs bst relation_and_target;;

(* ------------------------------------------------------------------------- *)
(* Model-check E(p U q) by iterating a |-> q \/ p /\ Pre(a) from "false".    *)
(* ------------------------------------------------------------------------- *)

(* One step of the E(p U q) iteration: map a to q \/ (p /\ Pre(a)).          *)
let step_EU evs r p q bst a =
  let after_pre, pre_a = bdd_Pre evs bst (r, a) in
  let after_and, p_and_pre = bdd_And after_pre (p, pre_a) in
  bdd_Or after_and (q, p_and_pre);;

(* Iterate step_EU to a fixpoint, starting from the false node (-1).         *)
let check_EU evs r bst (p,q) =
  let step = step_EU evs r p q in
  let bottom = -1 in
  iterate_to_fixpoint step bst bottom;;

(* ------------------------------------------------------------------------- *)
(* Model-check EG p by iterating a |-> p /\ Pre(a) from "true".              *)
(* ------------------------------------------------------------------------- *)

(* One step of the EG iteration: map a to p /\ Pre(a).                       *)
let step_EG evs r p bst a =
  match bdd_Pre evs bst (r, a) with
  | bst', pre_a -> bdd_And bst' (p, pre_a);;

(* Iterate step_EG to a fixpoint, starting from the true node (1).           *)
let check_EG evs r bst p =
  let step = step_EG evs r p in
  let top = 1 in
  iterate_to_fixpoint step bst top;;

(* ------------------------------------------------------------------------- *)
(* Main symbolic model checking function.                                    *)
(* ------------------------------------------------------------------------- *)

(* Compute the BDD node for the CTL formula fm, threading the BDD state.     *)
(* vars is the list of variables handed to the EX/EG/EU checkers for         *)
(* quantification and r the node of the transition relation.  The            *)
(* propositional connectives map onto the BDD operations; EX, EG and EU are  *)
(* computed by their checkers; every other temporal form is rewritten into   *)
(* those three and re-checked.                                               *)
let rec modelcheck vars r bst fm =
  let recur = modelcheck vars r in
  (* Check one or two subformulas, then combine the nodes with op. *)
  let unary p op = single recur bst p op in
  let binary p q op = double recur bst p q op in
  match fm with
  | Falsec -> (bst, -1)
  | Truec -> (bst, 1)
  | Propvarc s -> bdd_Node (P s) bst (1, -1)
  | Notc p -> unary p bdd_Not
  | Andc (p, q) -> binary p q bdd_And
  | Orc (p, q) -> binary p q bdd_Or
  | Impc (p, q) -> binary p q bdd_Imp
  | Iffc (p, q) -> binary p q bdd_Iff
  | AF p -> recur bst (Notc (EG (Notc p)))
  | AG p -> recur bst (Notc (EF (Notc p)))
  | AX p -> recur bst (Notc (EX (Notc p)))
  | AU (p, q) ->
      let nq = Notc q in
      recur bst (Andc (Notc (EU (nq, Andc (Notc p, nq))), Notc (EG nq)))
  | EF p -> recur bst (EU (Truec, p))
  | EG p -> unary p (check_EG vars r)
  | EX p -> unary p (check_EX vars r)
  | EU (p, q) -> binary p q (check_EU vars r);;

(* ------------------------------------------------------------------------- *)
(* Overall model-checking function.                                          *)
(* ------------------------------------------------------------------------- *)

(* vars lists the state variable names, s is the start condition, r the      *)
(* transition relation and p a CTL formula.  Builds BDDs for s and r in a    *)
(* fresh state, model-checks p quantifying over the primed variables, and    *)
(* reports whether s ==> p comes out as the true node.                       *)
let model_check vars s r p =
  let primed = map (fun name -> P (name ^ "'")) vars in
  let empty_state = (mk_bdd (fun a b -> a < b), undefined, undefined) in
  let bst1, [n_s; n_r] = bdd_Makes empty_state [s; r] in
  let bst2, holds_in = modelcheck primed n_r bst1 p in
  let _, verdict = bdd_Imp bst2 (n_s, holds_in) in
  verdict = 1;;

(* ------------------------------------------------------------------------- *)
(* Some simple examples.                                                     *)
(* ------------------------------------------------------------------------- *)

(*
let [v0; v1; v2; p0; p1; p2] = map (fun s -> Propvarc(s))
    ["v0"; "v1"; "v2"; "p0"; "p1"; "p2"];;

model_check ["v2"; "v1"; "v0"] <<true>> counter_trans
  (AF(Andc(v1,AX(Notc(v1)))));;

let s = <<~p2 /\ ~p1 /\ ~p0 /\ ~q2 /\ ~q1 /\ ~q0 /\ ~v1 /\ ~v0>>
and fm = AG(Impc(Andc(Notc(p0),Andc(Notc(p1),Notc(p2))),
                EF(Andc(p0,Andc(Notc(p1),Notc(p2))))))
and vars = ["p2"; "p1"; "p0"; "q2"; "q1"; "q0"; "v1"; "v0"] in
model_check vars s mutex_trans fm;;

(* ------------------------------------------------------------------------- *)
(* Failure of fairness even for correct algorithm.                           *)
(* ------------------------------------------------------------------------- *)

let s =
  <<~p2 /\ ~p1 /\ ~p0 /\ ~q2 /\ ~q1 /\ ~q0 /\ ~f1 /\ ~f0 /\ ~t>>
and fm = AG(Impc(Andc(Notc(p0),Andc(Notc(p1),Notc(p2))),
                AF(Andc(p0,Andc(Notc(p1),Notc(p2))))))
and vars = ["p2"; "p1"; "p0"; "q2"; "q1"; "q0"; "f2"; "f1"; "t"] in
model_check vars s peter_trans fm;;
*)

(* ------------------------------------------------------------------------- *)
(* Model checking with fairness.                                             *)
(* ------------------------------------------------------------------------- *)

(* One iteration step for fair EG. Starting from the approximation a, each   *)
(* fairness constraint f in fcs (taken left to right) refines it to          *)
(*   a /\ Pre(E[p U (a /\ f)])                                               *)
(* and the final approximation is conjoined with p. The BDD state is         *)
(* threaded through all operations and returned with the resulting node.     *)

let stepfair_EG evs r fcs p bst a =
  let refine (bst, approx) fc =
    let bst, approx_fc = bdd_And bst (approx, fc) in
    let bst, until = check_EU evs r bst (p, approx_fc) in
    let bst, pre = bdd_Pre evs bst (r, until) in
    bdd_And bst (approx, pre) in
  let bst', a' = List.fold_left refine (bst, a) fcs in
  bdd_And bst' (p, a');;

(* EG under fairness constraints fcs, computed by iterating a step function  *)
(* to a fixpoint starting from node 1. With no constraints the ordinary      *)
(* step_EG is used; otherwise the fair step stepfair_EG.                     *)

let checkfair_EG evs r fcs bst p =
  let step =
    match fcs with
    | [] -> step_EG evs r p
    | _ :: _ -> stepfair_EG evs r fcs p in
  iterate_to_fixpoint step bst 1;;

(* Fair EX: restrict p to the states satisfying fair EG of true (node 1),    *)
(* then apply the ordinary check_EX to the restricted set.                   *)

let checkfair_EX evs r fcs bst p =
  let bst, fair = checkfair_EG evs r fcs bst 1 in
  let bst, p_and_fair = bdd_And bst (p, fair) in
  check_EX evs r bst p_and_fair;;

(* Fair EU: restrict the target q to the states satisfying fair EG of true   *)
(* (node 1), then apply the ordinary check_EU with the restricted target.    *)

let checkfair_EU evs r fcs bst (p,q) =
  let bst, fair = checkfair_EG evs r fcs bst 1 in
  let bst, q_and_fair = bdd_And bst (q, fair) in
  check_EU evs r bst (p, q_and_fair);;

(* Fair CTL model checker: same structure as the plain checker, but the      *)
(* primitive operators EG, EX and EU go through the fairness-aware versions  *)
(* checkfair_EG, checkfair_EX and checkfair_EU, which receive the fairness   *)
(* constraints fcs. Returns the updated BDD state paired with the node for   *)
(* fm; the constants false and true are the nodes -1 and 1.                  *)

let rec fmodelcheck vars r fcs bst fm =
  let sub = fmodelcheck vars r fcs in
  let unary p op = single sub bst p op
  and binary p q op = double sub bst p q op in
  match fm with
  | Falsec -> (bst, -1)
  | Truec -> (bst, 1)
  | Propvarc name -> bdd_Node (P name) bst (1, -1)
  | Notc p -> unary p bdd_Not
  | Andc (p, q) -> binary p q bdd_And
  | Orc (p, q) -> binary p q bdd_Or
  | Impc (p, q) -> binary p q bdd_Imp
  | Iffc (p, q) -> binary p q bdd_Iff
  (* Primitive existential operators, fairness-aware. *)
  | EG p -> unary p (checkfair_EG vars r fcs)
  | EX p -> unary p (checkfair_EX vars r fcs)
  | EU (p, q) -> binary p q (checkfair_EU vars r fcs)
  (* Derived operators, expressed through the primitive ones. *)
  | EF p -> sub bst (EU (Truec, p))
  | AF p -> sub bst (Notc (EG (Notc p)))
  | AG p -> sub bst (Notc (EF (Notc p)))
  | AX p -> sub bst (Notc (EX (Notc p)))
  | AU (p, q) ->
      (* A[p U q] = ~E[~q U (~p /\ ~q)] /\ ~EG(~q) *)
      let not_q = Notc q in
      sub bst (Andc (Notc (EU (not_q, Andc (Notc p, not_q))),
                     Notc (EG not_q)));;

(* ------------------------------------------------------------------------- *)
(* Overall packaging.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Top-level fair CTL check. s, r and every element of fcs are propositional *)
(* formulas turned into BDDs in one batch; p is the CTL property. Each name  *)
(* in vars is primed before being handed to fmodelcheck. The answer is true  *)
(* exactly when the BDD for s ==> p is the constant node 1.                  *)

let fair_model_check vars s r p fcs =
  let primed = map (fun name -> P (name ^ "'")) vars in
  let empty_state = (mk_bdd (<), undefined, undefined) in
  let with_inputs, start :: trans :: fairness =
    bdd_Makes empty_state (s :: r :: fcs) in
  let with_prop, sat = fmodelcheck primed trans fairness with_inputs p in
  let _, verdict = bdd_Imp with_prop (start, sat) in
  verdict = 1;;
(* ========================================================================= *)
(* LTL decision procedure based on reduction to fair CTL model checking.     *)
(*                                                                           *)
(* Basically follows Clarke et al's "Another look at LTL model checking"     *)
(* paper, though it's presented in a somewhat different style.               *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Prime all propositional variables in a term to perform "next" op.         *)
(* ------------------------------------------------------------------------- *)

(* Replace every atom of fm by the atom whose name has a prime appended,     *)
(* by building a substitution over atoms fm and applying it with propsubst.  *)

let next fm =
  let add_primed atom subst =
    (atom |-> Atom (P (pname atom ^ "'"))) subst in
  let priming = itlist add_primed (atoms fm) undefined in
  propsubst priming fm;;

(* ------------------------------------------------------------------------- *)
(* Transform the formula to fair CTL model checking.                         *)
(* ------------------------------------------------------------------------- *)

(* Translate a temporal formula into a propositional one. Argument and       *)
(* result are quadruples: the formula, the accumulated definitions, the      *)
(* accumulated fairness constraints, and the index for the next fresh        *)
(* variable. Propositional connectives are translated structurally; each     *)
(* temporal operator is replaced by a fresh variable (see ltl1).             *)

let rec ltl (fm,defs,fcs,n as state) =
  match fm with
  | Falset -> (False, defs, fcs, n)
  | Truet -> (True, defs, fcs, n)
  | Propvart name -> (Atom (P name), defs, fcs, n)
  | Nott p ->
      let body, defs', fcs', n' = ltl (p, defs, fcs, n) in
      (Not body, defs', fcs', n')
  | Andt (p, q) -> ltl2 (fun (l, r) -> And (l, r)) (p, q) state
  | Ort (p, q) -> ltl2 (fun (l, r) -> Or (l, r)) (p, q) state
  | Impt (p, q) -> ltl2 (fun (l, r) -> Imp (l, r)) (p, q) state
  | Ifft (p, q) -> ltl2 (fun (l, r) -> Iff (l, r)) (p, q) state
  | Next p ->
      ltl1 p (fun body _ -> next body) (fun _ _ -> []) state
  | Box p ->
      ltl1 p (fun body v -> And (body, next v))
             (fun body v -> [Imp (body, v)]) state
  | Diamond p ->
      ltl1 p (fun body v -> Or (body, next v))
             (fun body v -> [Imp (v, body)]) state

(* Unary temporal operator: translate p to p', take the variable v produced  *)
(* by mkprop, record the definition v <=> cons1 p' v and prepend the         *)
(* fairness constraints cons2 p' v. The returned index is n' + 1; the index  *)
(* returned by mkprop itself is not used.                                    *)
and ltl1 p cons1 cons2 (_,defs,fcs,n) =
  let body, defs', fcs', n' = ltl (p, defs, fcs, n) in
  let v, _ = mkprop n' in
  (v, Iff (v, cons1 body v) :: defs', cons2 body v @ fcs', n' +/ Int 1)

(* Binary connective: translate p, then q with the state left by p, and      *)
(* combine the two translated formulas with cons.                            *)
and ltl2 cons (p,q) (_,defs,fcs,n) =
  let fm1, defs1, fcs1, n1 = ltl (p, defs, fcs, n) in
  let fm2, defs2, fcs2, n2 = ltl (q, defs1, fcs1, n1) in
  (cons (fm1, fm2), defs2, fcs2, n2);;

(* ------------------------------------------------------------------------- *)
(* Iterator analogous to "overatoms" for propositional logic.                *)
(* ------------------------------------------------------------------------- *)

(* Fold f over the propositional variable names of a temporal formula.       *)
(* For a binary connective the right operand is folded first and its result  *)
(* is the accumulator for the left operand. Any constructor not listed       *)
(* explicitly leaves the accumulator unchanged.                              *)

let rec itpropt f fm a =
  let walk sub acc = itpropt f sub acc in
  match fm with
  | Propvart name -> f name a
  | Nott p | Next p | Box p | Diamond p -> walk p a
  | Andt (p, q) | Ort (p, q) | Impt (p, q) | Ifft (p, q) ->
      walk p (walk q a)
  | _ -> a;;

(* ------------------------------------------------------------------------- *)
(* Get propositional variables in a temporal formula.                        *)
(* ------------------------------------------------------------------------- *)

(* Collect every variable name of fm with itpropt, then pass the list        *)
(* through setify.                                                           *)
let propvarst fm =
  let collected = itpropt (fun name acc -> name :: acc) fm [] in
  setify collected;;

(* ------------------------------------------------------------------------- *)
(* We also need to avoid primed variables.                                   *)
(* ------------------------------------------------------------------------- *)

(* Variant of max_varindex pfx that also looks through one trailing prime:   *)
(* the index is updated with s itself and, when s ends in a prime, with s    *)
(* minus that final character as well. The empty string leaves n unchanged.  *)

let max_varindex' pfx =
  let index_of = max_varindex pfx in
  fun s n ->
    match String.length s with
    | 0 -> n
    | len ->
        let last = len - 1 in
        let best = index_of s n in
        if s.[last] = '\'' then index_of (String.sub s 0 last) best
        else best;;

(* ------------------------------------------------------------------------- *)
(* Make a variable name "p_n".                                               *)
(* ------------------------------------------------------------------------- *)

(* Name of the auxiliary variable with index n: the prefix p_ followed by    *)
(* the decimal rendering of n given by string_of_num.                        *)
let mkname n =
  let index = string_of_num n in
  "p_" ^ index;;

(* ------------------------------------------------------------------------- *)
(* Overall LTL decision procedure (we add box to make sure top is variable). *)
(* ------------------------------------------------------------------------- *)

(* Decide the temporal formula fm. Box fm is translated by ltl starting at   *)
(* the first index not already used by a p_ variable of fm, so the result    *)
(* is a single variable. The fair model checker is then asked whether AG of  *)
(* that variable holds, with start condition True, the conjunction of the    *)
(* definitions as transition relation and the collected fairness             *)
(* constraints. The variable list is the variables of fm followed by the     *)
(* auxiliary names introduced by the translation.                            *)

let ltldecide fm =
  let first_fresh = Int 1 +/ itpropt (max_varindex' "p_") fm (Int 0) in
  let Atom (P top_var), defs, fcs, next_fresh =
    ltl (Box fm, [], [], first_fresh) in
  let aux_names = map mkname (first_fresh --- (next_fresh -/ Int 1)) in
  let vars = propvarst fm @ aux_names in
  let trans = list_conj defs in
  fair_model_check vars True trans (AG (Propvarc top_var)) fcs;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
let fm = let p = Propvart "p" in Impt(Next(Box p),Diamond(p));;

ltldecide fm;;
*)

(* ------------------------------------------------------------------------- *)
(* Alternative version moving p into the starting states.                    *)
(* ------------------------------------------------------------------------- *)

(* Alternative decision procedure: fm itself is translated (no outer Box),   *)
(* its negated translation becomes the start condition, and the property     *)
(* checked is that no fair path exists, i.e. ~EG(true).                      *)

let ltldecide' fm =
  let first_fresh = Int 1 +/ itpropt (max_varindex' "p_") fm (Int 0) in
  let translated, defs, fcs, next_fresh = ltl (fm, [], [], first_fresh) in
  let aux_names = map mkname (first_fresh --- (next_fresh -/ Int 1)) in
  let vars = propvarst fm @ aux_names in
  let trans = list_conj defs in
  fair_model_check vars (Not translated) trans (Notc (EG Truec)) fcs;;

(* ------------------------------------------------------------------------- *)
(* A parser, just to make testing nicer.                                     *)
(* ------------------------------------------------------------------------- *)

(* Parser for temporal formulas over a token list. The binary connectives    *)
(* are right-infix operators layered so that conjunction is the innermost    *)
(* level, followed by disjunction, implication, and equivalence outermost.   *)

let rec parse_tformula inp =
  let conj =
    parse_right_infix "/\\" (fun (p, q) -> Andt (p, q)) parse_tunary in
  let disj = parse_right_infix "\\/" (fun (p, q) -> Ort (p, q)) conj in
  let imp = parse_right_infix "==>" (fun (p, q) -> Impt (p, q)) disj in
  let iff = parse_right_infix "<=>" (fun (p, q) -> Ifft (p, q)) imp in
  iff inp

(* Prefix operators: negation, next (an opening and a closing parenthesis    *)
(* in a row), box (an opening and a closing square bracket in a row) and     *)
(* diamond. Anything else is parsed as an atom.                              *)
and parse_tunary inp =
  let prefixed mk rest = papply mk (parse_tunary rest) in
  match inp with
  | "~" :: rest -> prefixed (fun e -> Nott e) rest
  | "(" :: ")" :: rest -> prefixed (fun e -> Next e) rest
  | "[" :: "]" :: rest -> prefixed (fun e -> Box e) rest
  | "<>" :: rest -> prefixed (fun e -> Diamond e) rest
  | _ -> parse_tatom inp

(* Atoms: the two constants, a bracketed formula, or any other single token  *)
(* taken as a propositional variable name.                                   *)
and parse_tatom inp =
  match inp with
  | [] -> failwith "Expected an expression at end of input"
  | "false" :: rest -> (Falset, rest)
  | "true" :: rest -> (Truet, rest)
  | "(" :: rest -> parse_bracketed parse_tformula ")" rest
  | name :: rest -> (Propvart name, rest);;

(* Parse a whole string as a temporal formula: explode and lex it, run       *)
(* parse_tformula, and fail if any token is left over.                       *)
let parsel s =
  match parse_tformula (lex (explode s)) with
  | fm, [] -> fm
  | _, _ :: _ -> failwith "Unparsed input";;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Rebind default_parser to the temporal-formula parser parsel.              *)
(* NOTE(review): presumably this is what the << ... >> quotations in the     *)
(* examples below are parsed with -- confirm against the quotation setup.    *)
let default_parser = parsel;;

(*
ltldecide << <>[]p ==> []<>()<>[]p >>;;

ltldecide' << [](p ==> ()p) ==> [](p ==> []p) >>;;

ltldecide' << []<>p ==> <>[]p >>;;

(* ------------------------------------------------------------------------- *)
(* Compare performances (and check results!) on test cases.                  *)
(* ------------------------------------------------------------------------- *)

let test fm =
  let a = time ltldecide fm in
  let a' = time ltldecide' fm in
  if a = a' then a else failwith("*** Disparity");;

test << (()[]p ==> <>p) >>;;

test << [] (()([]p) ==> <>p) >>;;

test << <>p ==> ()<>p >>;;

test << ()<>p ==> <>p >>;;

test << <>[]p ==> <>[]p >>;;

test << <>[]p ==> []<>p >>;;

test << ()(p /\ q) <=> () p /\ () q >>;;

test << [](p /\ q) <=> [] p /\ [] q >>;;

test << <>(p /\ q) <=> <> p /\ <> q >>;;

test << <>(p /\ q) ==> <> p /\ <> q >>;;

test << [](p ==> ()p) ==> [](p ==> []p) >>;;

test << [](p ==> ()p) ==> p ==> []p >>;;

test << [](p ==> ()p) ==> []p >>;;

test << [](p ==> ()q) /\ [](q ==> ()p) ==> []<>p >>;;

test << [](p ==> ()q) /\ [](q ==> ()p) ==> <>p ==> []<>p >>;;

test << <>(<>p) <=> <>p >>;;

test << [][]p <=> []p >>;;

test << ()[]p <=> []()p >>;;

test << []p ==> <>p >>;;

test << []p ==> ()p >>;;

test << ()p ==> <>p >>;;

test << [](p ==> ()p) ==> ()p ==> p ==> [] p >>;;

test << ~[]p <=> <>(~p) >>;;

test << []p ==> p >>;;

test << [](p ==> []p) ==> (p <=> []p) >>;;

test << []([]p ==> []q) <=> []([]p ==> q) >>;;

test << <>[]p ==> ()()()()<>[]p >>;;

test << <>[]p ==> ()()()()()()()()()<>[]p >>;;

test << <>[]<>p ==> []<>p >>;;

test << <>[]<>p ==> []<>[]<>p >>;;

test << ()[]p ==> []p >>;;

*)
(* ========================================================================= *)
(* Symbolic trajectory evaluation (STE).                                     *)
(*                                                                           *)
(* Based on Melham-Darbari presentation and John O'Leary's tutorial code.    *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Rebind default_parser to parsep for the rest of this section.             *)
(* NOTE(review): presumably this is what the << ... >> quotations in the     *)
(* examples below are parsed with -- confirm against the quotation setup.    *)
let default_parser = parsep;;

(* ------------------------------------------------------------------------- *)
(* Quaternary lattice.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Four-valued lattice. In the ordering defined just below, X lies under     *)
(* every value and T above every value, while ZERO and ONE are only          *)
(* comparable to themselves, X and T.                                        *)
type quat =
  | X      (* least element                                   *)
  | ZERO   (* definite value 0                                *)
  | ONE    (* definite value 1                                *)
  | T      (* greatest element (overconstrained)              *)
;;

(* ------------------------------------------------------------------------- *)
(* Basic lattice operations.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Lattice ordering on quat: x <== y holds when x is X, when y is T, or      *)
(* when the two values are equal.                                            *)
let (<==) x y =
  match x, y with
  | X, _ | _, T -> true
  | _, _ -> x = y;;

(* Lattice join on quat: the larger of two comparable values, and T when     *)
(* they are incomparable. Note that from here on (&&) no longer denotes      *)
(* boolean conjunction.                                                      *)
let (&&) x y =
  match x <== y with
  | true -> y
  | false -> if y <== x then x else T;;

(* ------------------------------------------------------------------------- *)
(* Boolean extensions.                                                       *)
(* ------------------------------------------------------------------------- *)

(* The list of booleans a lattice value stands for: both for X, exactly one  *)
(* for ZERO and ONE, and none for T.                                         *)
let bools = function
  | T -> []
  | ZERO -> [false]
  | ONE -> [true]
  | X -> [false; true];;

(* ------------------------------------------------------------------------- *)
(* Converse.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Converse of bools: map a list of booleans back to a lattice value. A      *)
(* list that is not literally one of the four canonical forms is passed      *)
(* through setify and examined again.                                        *)
let rec quat = function
  | [] -> T
  | [false] -> ZERO
  | [true] -> ONE
  | [false; true] -> X
  | other -> quat (setify other);;

(* ------------------------------------------------------------------------- *)
(* Deduce the ternary or quaternary extensions of operations.                *)
(* ------------------------------------------------------------------------- *)

(* Print the lattice-valued truth table of the propositional formula fm.     *)
(* Each row assigns a lattice value to every atom; the result column is      *)
(* obtained by evaluating fm under every boolean valuation compatible with   *)
(* the row (via bools) and mapping the set of outcomes back with quat.       *)
(* With quaf true the values X, 0, 1 and T are tabulated, otherwise only     *)
(* X, 0 and 1.                                                               *)

let print_quattable quaf fm =
  let pvs = atoms fm in
  (* Column width: longest atom name, but at least 5, plus one for spacing. *)
  let width =
    itlist (fun p widest -> max (String.length (pname p)) widest) pvs 5 + 1 in
  let fixw s = s ^ String.make (width - String.length s) ' ' in
  let truthstring q =
    match q with
    | X -> "X"
    | ZERO -> "0"
    | ONE -> "1"
    | T -> "T" in
  let testqs assig =
    let assigs = itlist (allpairs (fun h t -> h :: t)) assig [[]] in
    quat (map (fun a -> eval fm (apply (instantiate pvs a))) assigs) in
  let mk_row v =
    let ufn = instantiate pvs v in
    let cells = map (fun p -> fixw (truthstring (apply ufn p))) pvs in
    let answer = fixw (truthstring (testqs (map bools v))) in
    print_string (itlist (^) cells ("| " ^ answer));
    print_newline () in
  let separator = String.make (width * length pvs + 9) '-' in
  let header =
    itlist (fun p rest -> fixw (pname p) ^ rest) pvs "| formula" in
  print_string header;
  print_newline ();
  print_string separator;
  print_newline ();
  let values = if quaf then [X; ZERO; ONE; T] else [X; ZERO; ONE] in
  do_list mk_row (alltuples (length pvs) values);;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
print_quattable true <<~p>>;;

print_quattable false <<p /\ q>>;;

print_quattable false <<p \/ q>>;;

print_quattable false <<p ==> q>>;;

print_quattable false <<p <=> q>>;;

print_quattable false <<~(p <=> q)>>;;

(* ------------------------------------------------------------------------- *)
(* Example of pessimism from composing truth tables.                         *)
(* ------------------------------------------------------------------------- *)

print_quattable false <<p /\ ~p>>;;
*)

(* ------------------------------------------------------------------------- *)
(* Choice operator for Boolean parametrization.                              *)
(* ------------------------------------------------------------------------- *)

(* Guarded choice: x when the guard b is true, and X otherwise. *)
let (>->) b x =
  match b with
  | true -> x
  | false -> X;;

(* ------------------------------------------------------------------------- *)
(* Spurious abstraction, but it might be useful later.                       *)
(* ------------------------------------------------------------------------- *)

(* A circuit node, identified by its name. *)
type node =
  | Node of string
;;

(* ------------------------------------------------------------------------- *)
(* Type of trajectory formulas.                                              *)
(* ------------------------------------------------------------------------- *)

(* Trajectory formulas. *)
type trajform =
  | Is_0 of node                    (* the node has value 0 now           *)
  | Is_1 of node                    (* the node has value 1 now           *)
  | Andj of trajform * trajform     (* conjunction                        *)
  | When of trajform * prop formula (* formula guarded by a proposition   *)
  | Next of trajform                (* formula shifted one time step on   *)
;;

(* ------------------------------------------------------------------------- *)
(* Abstract formula semantics with propositional valuation as last argument. *)
(* ------------------------------------------------------------------------- *)

(* Whether the trajectory formula tf holds of the sequence seq under the     *)
(* propositional valuation v. seq is called as seq time node valuation and   *)
(* returns a lattice value; node requirements are checked at time 0, and     *)
(* Next shifts the sequence by one step.                                     *)

let rec tholds tf seq v =
  let shifted t = seq (t + 1) in
  match tf with
  | Is_0 nd -> ZERO <== seq 0 nd v
  | Is_1 nd -> ONE <== seq 0 nd v
  | Andj (left, right) -> tholds left seq v & tholds right seq v
  | When (body, guard) -> eval guard v --> tholds body seq v
  | Next body -> tholds body shifted v;;

(* Defining sequence of tf: the lattice value tf imposes on node nd at time  *)
(* t under valuation v. Here && is the lattice join defined above and >->    *)
(* the guarded choice, so an unconstrained position yields X.                *)

let rec defseq tf t nd v =
  let here n = (n = nd & t = 0) in
  match tf with
  | Is_0 n -> here n >-> ZERO
  | Is_1 n -> here n >-> ONE
  | Andj (left, right) -> defseq left t nd v && defseq right t nd v
  | When (body, guard) -> eval guard v >-> defseq body t nd v
  | Next body -> (t <> 0) >-> defseq body (t - 1) nd v;;

(* Defining trajectory: at time 0 just the defining sequence; at a later     *)
(* time the join (&&) of the defining sequence with the result of applying   *)
(* step to the trajectory at the previous time.                              *)

let rec deftraj step tf t nd v =
  match t with
  | 0 -> defseq tf t nd v
  | _ -> defseq tf t nd v && step (deftraj step tf (t - 1)) nd v;;

(* ------------------------------------------------------------------------- *)
(* Depth of a trajectory formula.                                            *)
(* ------------------------------------------------------------------------- *)

(* Largest nesting depth of Next in a trajectory formula. *)
let rec timedepth = function
  | Is_0 _ | Is_1 _ -> 0
  | When (body, _) -> timedepth body
  | Next body -> 1 + timedepth body
  | Andj (left, right) -> max (timedepth left) (timedepth right);;

(* ------------------------------------------------------------------------- *)
(* Reformulation that will work better when we use finite partial functions. *)
(* ------------------------------------------------------------------------- *)

(* Pointwise update of a two-argument function f: at the point (a1,a2) the   *)
(* result is the lattice join (&&) of y with the old value, and everywhere   *)
(* else f is unchanged.                                                      *)
let constrain a1 a2 y f x1 x2 =
  let current = f x1 x2 in
  if x1 = a1 & x2 = a2 then y && current else current;;

(* Defining sequence again, built by accumulating constraints onto the       *)
(* everywhere-X sequence. The traversal carries the current time offset t0   *)
(* (incremented under Next) and the conjunction g of the guards met so far   *)
(* (extended under When); each Is_0 / Is_1 constrains node n at time t0 to   *)
(* the guarded value.                                                        *)

let defseq tf t nd v =
  let rec build t0 g tf seq =
    match tf with
    | Is_0 n -> constrain t0 n (g >-> ZERO) seq
    | Is_1 n -> constrain t0 n (g >-> ONE) seq
    | Andj (left, right) -> build t0 g right (build t0 g left seq)
    | When (body, guard) -> build t0 (eval guard v & g) body seq
    | Next body -> build (t0 + 1) g body seq in
  build 0 true tf (fun _ _ -> X) t nd;;

(* ------------------------------------------------------------------------- *)
(* The dual-rail encoding.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Dual-rail constants: each lattice value is a pair of BDD node indices,    *)
(* where 1 and -1 are the constant nodes true and false.                     *)
let top = (-1, -1);;
let one = (1, -1);;
let zero = (-1, 1);;
let bot = (1, 1);;

(* ------------------------------------------------------------------------- *)
(* Lattice ordering as a BDD operation.                                      *)
(* ------------------------------------------------------------------------- *)

(* Lattice ordering on dual-rail pairs, as a BDD: the conjunction of         *)
(* h2 ==> h1 and l2 ==> l1. Returns the updated BDD state and the node.      *)
let leq bst (h1,l1) (h2,l2) =
  let bst, high_ok = bdd_Imp bst (h2, h1) in
  let bst, low_ok = bdd_Imp bst (l2, l1) in
  bdd_And bst (high_ok, low_ok);;

(* ------------------------------------------------------------------------- *)
(* The lattice join as a BDD operation.                                      *)
(* ------------------------------------------------------------------------- *)

(* Lattice join on dual-rail pairs: the componentwise conjunction of the     *)
(* two rails. Returns the updated BDD state and the resulting pair.          *)
let join bst (h1,l1) (h2,l2) =
  let bst, high = bdd_And bst (h1, h2) in
  let bst, low = bdd_And bst (l1, l2) in
  (bst, (high, low));;

(* ------------------------------------------------------------------------- *)
(* Choice as a BDD operation.                                                *)
(* ------------------------------------------------------------------------- *)

(* Guarded choice on dual-rail pairs: each rail r becomes b ==> r, so where  *)
(* the guard b is false both rails are true, i.e. the pair is bot.           *)
let bchoice bst b (h,l) =
  let bst, high = bdd_Imp bst (b, h) in
  let bst, low = bdd_Imp bst (b, l) in
  (bst, (high, low));;

(* ------------------------------------------------------------------------- *)
(* Form the defining sequence.                                               *)
(* ------------------------------------------------------------------------- *)

(* Join the value y onto node n at time t in the sequence seq, which is a    *)
(* finite partial function from times to states (themselves finite partial   *)
(* functions from nodes to dual-rail pairs). A missing time defaults to the  *)
(* empty state and a missing node to bot.                                    *)
let constrain bst t n y seq =
  let state = tryapplyd seq t undefined in
  let previous = tryapplyd state n bot in
  let bst', joined = join bst previous y in
  let state' = (n |-> joined) state in
  (bst', (t |-> state') seq);;

(* Defining sequence in the BDD representation. The traversal carries the    *)
(* BDD state, the current time offset t0 (incremented under Next) and the    *)
(* BDD node g for the conjunction of guards met so far (starting at 1 and    *)
(* extended under When). Each Is_0 / Is_1 constrains its node at time t0 to  *)
(* zero / one guarded by g. Returns the BDD state and the sequence.          *)

let defseq bst tf =
  let rec build bst t0 g tf seq =
    match tf with
    | Is_0 n ->
        let bst, value = bchoice bst g zero in
        constrain bst t0 n value seq
    | Is_1 n ->
        let bst, value = bchoice bst g one in
        constrain bst t0 n value seq
    | Andj (left, right) ->
        let bst, seq = build bst t0 g left seq in
        build bst t0 g right seq
    | When (body, guard) ->
        let bst, guard_node = bdd_Make bst guard in
        let bst, g' = bdd_And bst (guard_node, g) in
        build bst t0 g' body seq
    | Next body -> build bst (t0 + 1) g body seq in
  build bst 0 1 tf undefined;;

(* ------------------------------------------------------------------------- *)
(* Now the defining trajectory.                                              *)
(* ------------------------------------------------------------------------- *)

(* Defining trajectory in the BDD representation, up to time t. At time 0    *)
(* it is the defining sequence; for a later time the trajectory up to t - 1  *)
(* is built first, step is applied to its state at t - 1, and every          *)
(* (node, value) pair of the stepped state is joined in at time t.           *)

let rec deftraj bst step tf t =
  if t = 0 then defseq bst tf else
  let earlier = t - 1 in
  let bst, traj = deftraj bst step tf earlier in
  let previous = tryapplyd traj earlier undefined in
  let bst, stepped = step bst previous in
  let absorb (n, v) (bst, seq) = constrain bst t n v seq in
  itlist absorb (funset stepped) (bst, traj);;

(* ------------------------------------------------------------------------- *)
(* Check containment of sequences.                                           *)
(* ------------------------------------------------------------------------- *)

(* BDD for the condition that seq1 lies below seq2: for every time and node  *)
(* in the domain of seq1, the value in seq1 is leq the value in seq2, where  *)
(* anything missing from seq2 counts as bot. The conditions are conjoined    *)
(* starting from node 1.                                                     *)

let contained bst seq1 seq2 =
  let check_time t acc =
    let st1 = apply seq1 t and st2 = tryapplyd seq2 t undefined in
    let check_node n (bst, so_far) =
      let v1 = apply st1 n and v2 = tryapplyd st2 n bot in
      let bst, below = leq bst v1 v2 in
      bdd_And bst (so_far, below) in
    itlist check_node (dom st1) acc in
  itlist check_time (dom seq1) (bst, 1);;

(* ------------------------------------------------------------------------- *)
(* STE model checking algorithm.                                             *)
(* ------------------------------------------------------------------------- *)

(* STE check of antecedent a against consequent c for the circuit step       *)
(* function ckt: build the defining trajectory of a up to the time depth of  *)
(* c, the defining sequence of c, and return the BDD state and node for the  *)
(* containment of the latter in the former.                                  *)
let ste bst ckt (a,c) =
  let horizon = timedepth c in
  let bst, antecedent_traj = deftraj bst ckt a horizon in
  let bst, consequent_seq = defseq bst c in
  contained bst consequent_seq antecedent_traj;;

(* ------------------------------------------------------------------------- *)
(* Basic gates. Note that they don't work for overconstrained value T.       *)
(* ------------------------------------------------------------------------- *)

(* Inverter on dual-rail pairs: swap the two rails; the BDD state is         *)
(* returned untouched.                                                       *)
let not_gate bst rails =
  let (high, low) = rails in
  (bst, (low, high));;

(* AND gate on dual-rail pairs: conjunction of the first rails and           *)
(* disjunction of the second rails.                                          *)
let and_gate bst (h1,l1) (h2,l2) =
  let bst, high = bdd_And bst (h1, h2) in
  let bst, low = bdd_Or bst (l1, l2) in
  (bst, (high, low));;

(* OR gate on dual-rail pairs: disjunction of the first rails and            *)
(* conjunction of the second rails.                                          *)
let or_gate bst (h1,l1) (h2,l2) =
  let bst, high = bdd_Or bst (h1, h2) in
  let bst, low = bdd_And bst (l1, l2) in
  (bst, (high, low));;

(* ------------------------------------------------------------------------- *)
(* The next-state function.                                                  *)
(* ------------------------------------------------------------------------- *)

(* Next-state function of the example circuit: b0 takes the current value    *)
(* of a0, and each b(i) for i = 1..6 takes the AND of the current values of  *)
(* a(i) and b(i-1). All inputs are read from the incoming state st (missing  *)
(* nodes count as bot). An output equal to bot is not recorded, and the BDD  *)
(* state produced while computing it is dropped as well.                     *)

let step bst st =
  let value nd = tryapplyd st (Node nd) bot in
  let gates =
    [("a1","b0"),"b1";
     ("a2","b1"),"b2";
     ("a3","b2"),"b3";
     ("a4","b3"),"b4";
     ("a5","b4"),"b5";
     ("a6","b5"),"b6"] in
  let drive ((in1, in2), out) (bst, next_st) =
    let bst', z = and_gate bst (value in1) (value in2) in
    if z = bot then (bst, next_st)
    else (bst', (Node out |-> z) next_st) in
  let initial = (Node "b0" |-> value "a0") undefined in
  itlist drive gates (bst, initial);;

(* ------------------------------------------------------------------------- *)
(* Shorthands (really, it would be better to write a proper parser).         *)
(* ------------------------------------------------------------------------- *)

(* Infix shorthand for the conjunction of two trajectory formulas. *)
let (&&&) left right = Andj (left, right);;

(* Infix shorthand for guarding the trajectory formula f by proposition p. *)
let (<--) f p = When (f, p);;

(* Shorthands asserting that the node named s is 0, respectively 1. *)
let is0 s = Is_0 (Node s);;
let is1 s = Is_1 (Node s);;

(* Wrap the trajectory formula p in n applications of Next (none when n is   *)
(* below 1, as funpow then returns its argument unchanged).                  *)
let next n p =
  let delay tf = Next tf in
  funpow n delay p;;

(* ------------------------------------------------------------------------- *)
(* An example.                                                               *)
(* ------------------------------------------------------------------------- *)

(*
let a =
  (next 0 (is0 "a0") <-- <<~p /\ ~q /\ ~r>>) &&&
  (next 1 (is0 "a1") <-- <<~p /\ ~q /\  r>>) &&&
  (next 2 (is0 "a2") <-- <<~p /\  q /\ ~r>>) &&&
  (next 3 (is0 "a3") <-- <<~p /\  q /\  r>>) &&&
  (next 4 (is0 "a4") <-- << p /\ ~q /\ ~r>>) &&&
  (next 5 (is0 "a5") <-- << p /\ ~q /\  r>>) &&&
  (next 6 (is0 "a6") <-- << p /\  q /\ ~r>>) &&&
  (next 0 (is1 "a0") <-- << p /\  q /\  r>>) &&&
  (next 1 (is1 "a1") <-- << p /\  q /\  r>>) &&&
  (next 2 (is1 "a2") <-- << p /\  q /\  r>>) &&&
  (next 3 (is1 "a3") <-- << p /\  q /\  r>>) &&&
  (next 4 (is1 "a4") <-- << p /\  q /\  r>>) &&&
  (next 5 (is1 "a5") <-- << p /\  q /\  r>>) &&&
  (next 6 (is1 "a6") <-- << p /\  q /\  r>>);;

let c = (next 7 (is1 "b6") <-- <<p /\ q /\ r>>) &&&
        (next 7 (is0 "b6") <-- <<~p \/ ~q \/ ~r>>);;

let bst = mk_bdd (fun s1 s2 -> s1 < s2),undefined,undefined in
ste bst step (a,c);;
*)
(* ========================================================================= *)
(* Simple example of LCF-style prover: equational logic via Birkhoff rules.  *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Rebind default_parser to parse for the rest of this section.              *)
(* NOTE(review): presumably this is what the << ... >> quotations in the     *)
(* examples below are parsed with -- confirm against the quotation setup.    *)
let default_parser = parse;;

(* ------------------------------------------------------------------------- *)
(* LCF realization of Birkhoff-style rules for equational logic.             *)
(* ------------------------------------------------------------------------- *)

(* Interface of the equational-logic kernel. thm is abstract, so theorems    *)
(* can only be built with the rules listed here.                             *)
module type Birkhoff =
  sig
    type thm

    (* Introduce a formula as an axiom. *)
    val axiom : fol formula -> thm

    (* Instantiate variables in a theorem. *)
    val inst : (string, term) func -> thm -> thm

    (* Reflexivity, symmetry and transitivity of equality. *)
    val refl : term -> thm
    val sym : thm -> thm
    val trans : thm -> thm -> thm

    (* Congruence for the named function symbol. *)
    val cong : string -> thm list -> thm

    (* Break a theorem into its assumptions and its conclusion. *)
    val dest_thm : thm -> fol formula list * fol formula
  end;;

(* Kernel implementation. A theorem is a pair of the axioms it depends on    *)
(* and the equation it concludes.                                            *)
module Proveneq : Birkhoff =
  struct
    type thm = fol formula list * fol formula

    (* The equation s = t as a formula (hidden by the signature). *)
    let mk_eqn s t = Atom (R ("=", [s; t]))

    (* An axiom depends on itself; only equations are accepted. *)
    let axiom p =
      match p with
      | Atom (R ("=", [_; _])) -> ([p], p)
      | _ -> failwith "axiom: not an equation"

    (* Instantiation affects the conclusion only; the assumption list is     *)
    (* kept as it is.                                                        *)
    let inst i (asm, p) = (asm, formsubst i p)

    let refl t = ([], mk_eqn t t)

    let sym (asm, Atom (R ("=", [s; t]))) = (asm, mk_eqn t s)

    (* The middle terms must be identical; assumptions are merged. *)
    let trans (asm1, Atom (R ("=", [s; t]))) (asm2, Atom (R ("=", [t'; u]))) =
      if t' = t then (union asm1 asm2, mk_eqn s u)
      else failwith "trans: theorems don't match up"

    (* From si = ti for each theorem, conclude f(s1,..) = f(t1,..). *)
    let cong f ths =
      let split (asm, Atom (R ("=", [s; t]))) = (asm, (s, t)) in
      let asms, eqs = unzip (map split ths) in
      let ls, rs = unzip eqs in
      (unions asms, mk_eqn (Fn (f, ls)) (Fn (f, rs)))

    let dest_thm th = th
  end;;

(* ------------------------------------------------------------------------- *)
(* Printer.                                                                  *)
(* ------------------------------------------------------------------------- *)

open Proveneq;;

(* Print a theorem as its comma-separated assumptions, a break, the          *)
(* turnstile and the conclusion, all inside one formatting box.              *)
let print_thm th =
  let asl,c = dest_thm th in
  let show fm = print_formula print_atom 0 fm in
  open_box 0;
  (match asl with
   | [] -> ()
   | first :: rest ->
       show first;
       do_list (fun a -> print_string ","; print_space(); show a) rest);
  print_space();
  print_string "|- ";
  open_box 0;
  show c;
  close_box();
  close_box();;

(*
#install_printer print_thm;;
*)

(* ------------------------------------------------------------------------- *)
(* Using it to do a group theory example "manually".                         *)
(* ------------------------------------------------------------------------- *)

(*
let group_1 = axiom <<x * (y * z) = (x * y) * z>>;;
let group_2 = axiom <<1 * x = x>>;;
let group_3 = axiom <<i(x) * x = 1>>;;

let th1 = inst ("x" := <<|x * i(x)|>>) (sym group_2)
and th2 = cong "*" [inst ("x" := <<|i(x)|>>) (sym group_3);
                    refl <<|x * i(x)|>>]
and th3 = inst (instantiate ["x"; "y"; "z"]
                   [<<|i(i(x))|>>; <<|i(x)|>>; <<|x * i(x)|>>])
               (sym group_1)
and th4 =
  trans (inst (instantiate ["x"; "y"; "z"]
                   [<<|i(x)|>>; <<|x|>>; <<|i(x)|>>])
              group_1)
        (trans (cong "*" [group_3; refl <<|i(x)|>>])
               (inst ("x" := <<|i(x)|>>) group_2))
and th5 = inst ("x" := <<|i(x)|>>) group_3 in
end_itlist trans
 [th1; th2; th3; cong "*" [refl <<|i(i(x))|>>; th4]; th5];;
*)

(* ------------------------------------------------------------------------- *)
(* Trivial example of a derived rule.                                        *)
(* ------------------------------------------------------------------------- *)

(* From |- a = b derive |- a * t = b * t. *)
let lcong t th =
  let unchanged = refl t in
  cong "*" [th; unchanged];;

(* From |- a = b derive |- t * a = t * b. *)
let rcong t th =
  let unchanged = refl t in
  cong "*" [unchanged; th];;

(* ------------------------------------------------------------------------- *)
(* Rewriting derived rule.                                                   *)
(* ------------------------------------------------------------------------- *)

(* The conclusion of a theorem, dropping its assumptions. *)
let conclusion th =
  let _, concl = dest_thm th in
  concl;;

(* Rewrite at the top of t with the equational theorem eq: match the left    *)
(* side of eq against t with term_match and instantiate eq accordingly.      *)
(* Fails if the conclusion of eq is not an equation.                         *)
let rewrite1_conv eq t =
  match conclusion eq with
  | Atom (R ("=", [pattern; _])) ->
      let matching = term_match pattern t in
      inst matching eq
  | _ -> failwith "rewrite1_conv";;

(* Sequence two conversions: apply conv1 to t, apply conv2 to the right      *)
(* side of the result, and chain the two theorems by transitivity.           *)
let thenc conv1 conv2 t =
  let first = conv1 t in
  let intermediate = rhs (conclusion first) in
  let second = conv2 intermediate in
  trans first second;;

(* Apply the conversion fn repeatedly throughout tm. First fn is tried at    *)
(* the top, followed by a further full pass on the result. If that fails, a  *)
(* variable is left alone and a function application is rewritten inside     *)
(* its arguments; when this changed the term, the new term gets another      *)
(* full pass.                                                                *)
let rec depth fn tm =
  try thenc fn (depth fn) tm with Failure _ ->
    (match tm with
     | Var _ -> refl tm
     | Fn (f, args) ->
         let th = cong f (map (depth fn) args) in
         let tm' = rhs (conclusion th) in
         if tm' = tm then th else trans th (depth fn tm'));;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
depth (rewrite1_conv group_1) <<|(a * b * c) * (d * e) * f|>>;;
*)
(* ========================================================================= *)
(* LCF-style prover for Tarski-style Hilbert system of first order logic.    *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Basic first order deductive system.                                       *)
(*                                                                           *)
(* This is based on Tarski's trick for avoiding use of a substitution        *)
(* primitive. It seems about the simplest possible system we could use.      *)
(*                                                                           *)
(*  |- p ==> (q ==> p)                                                       *)
(*  |- (p ==> q ==> r) ==> (p ==> q) ==> (p ==> r)                           *)
(*  |- ((p ==> false) ==> false) ==> p                                       *)
(*  |- (forall x. p ==> q) ==> (forall x. p) ==> (forall x. q)               *)
(*  |- p ==> forall x. p                            [x not free in p]        *)
(*  |- (exists x. x = t)                            [x not free in t]        *)
(*  |- t = t                                                                 *)
(*  |- s1 = t1 ==> ... ==> sn = tn ==> f(s1,..,sn) = f(t1,..,tn)             *)
(*  |- s1 = t1 ==> ... ==> sn = tn ==> P(s1,..,sn) ==> P(t1,..,tn)           *)
(*  |- (p <=> q) ==> p ==> q                                                 *)
(*  |- (p <=> q) ==> q ==> p                                                 *)
(*  |- (p ==> q) ==> (q ==> p) ==> (p <=> q)                                 *)
(*  |- true <=> (false ==> false)                                            *)
(*  |- ~p <=> (p ==> false)                                                  *)
(*  |- p \/ q <=> ~(~p /\ ~q)                                                *)
(*  |- p /\ q <=> (p ==> q ==> false) ==> false                              *)
(*  |- (exists x. p) <=> ~(forall x. ~p)                                     *)
(*  if |- p ==> q and |- p then |- q                                         *)
(*  if |- p then |- forall x. p                                              *)
(* ------------------------------------------------------------------------- *)

(* Signature of the logical kernel.  The theorem type is abstract here, so   *)
(* the only way to build a thm is through the axiom schemas and the two      *)
(* inference rules declared below; concl exposes the proved formula.         *)
module type Proofsystem =
   sig type thm
       (* |- p ==> (q ==> p) *)
       val axiom_addimp : fol formula -> fol formula -> thm
       (* |- (p ==> q ==> r) ==> (p ==> q) ==> (p ==> r) *)
       val axiom_distribimp :
            fol formula -> fol formula -> fol formula -> thm
       (* |- ((p ==> false) ==> false) ==> p *)
       val axiom_doubleneg : fol formula -> thm
       (* |- (forall x. p ==> q) ==> (forall x. p) ==> (forall x. q) *)
       val axiom_allimp : string -> fol formula -> fol formula -> thm
       (* |- p ==> forall x. p, provided x is not free in p *)
       val axiom_impall : string -> fol formula -> thm
       (* |- exists x. x = t, provided x does not occur in t *)
       val axiom_existseq : string -> term -> thm
       (* |- t = t *)
       val axiom_eqrefl : term -> thm
       (* |- s1 = t1 ==> ... ==> sn = tn ==> f(s1,..,sn) = f(t1,..,tn) *)
       val axiom_funcong : string -> term list -> term list -> thm
       (* |- s1 = t1 ==> ... ==> sn = tn ==> P(s1,..,sn) ==> P(t1,..,tn) *)
       val axiom_predcong : string -> term list -> term list -> thm
       (* |- (p <=> q) ==> p ==> q *)
       val axiom_iffimp1 : fol formula -> fol formula -> thm
       (* |- (p <=> q) ==> q ==> p *)
       val axiom_iffimp2 : fol formula -> fol formula -> thm
       (* |- (p ==> q) ==> (q ==> p) ==> (p <=> q) *)
       val axiom_impiff : fol formula -> fol formula -> thm
       (* |- true <=> (false ==> false) *)
       val axiom_true : thm
       (* |- ~p <=> (p ==> false) *)
       val axiom_not : fol formula -> thm
       (* |- p \/ q <=> ~(~p /\ ~q) *)
       val axiom_or : fol formula -> fol formula -> thm
       (* |- p /\ q <=> (p ==> q ==> false) ==> false *)
       val axiom_and : fol formula -> fol formula -> thm
       (* |- (exists x. p) <=> ~(forall x. ~p) *)
       val axiom_exists : string -> fol formula -> thm
       (* From |- p ==> q and |- p conclude |- q *)
       val modusponens : thm -> thm -> thm
       (* From |- p conclude |- forall x. p *)
       val gen : string -> thm -> thm
       (* The formula that a theorem asserts *)
       val concl : thm -> fol formula
   end;;

(* ------------------------------------------------------------------------- *)
(* Auxiliary functions.                                                      *)
(* ------------------------------------------------------------------------- *)

(* occurs_in s t holds when the term s is t itself or one of the subterms    *)
(* found by descending through the arguments of function applications.      *)
(* Uses || rather than the deprecated keyword operator for disjunction.     *)
let rec occurs_in s t =
  s = t ||
  match t with
    Var _ -> false
  | Fn(_,args) -> exists (occurs_in s) args;;

(* free_in t fm holds when the term t occurs in fm outside the scope of any  *)
(* quantifier that binds a variable occurring in t.                          *)
(* Uses || and && rather than the deprecated keyword/ampersand operators;    *)
(* connectives that are treated alike share a single match arm.              *)
let rec free_in t fm =
  match fm with
    False | True -> false
  | Atom(R(_,args)) -> exists (occurs_in t) args
  | Not(p) -> free_in t p
  | And(p,q) | Or(p,q) | Imp(p,q) | Iff(p,q) -> free_in t p || free_in t q
  | Forall(y,p) | Exists(y,p) -> not (occurs_in (Var y) t) && free_in t p;;

(* ------------------------------------------------------------------------- *)
(* Implementation of the abstract data type of theorems.                     *)
(* ------------------------------------------------------------------------- *)

(* The kernel itself.  A theorem is represented directly by the formula it   *)
(* asserts, but because the module is sealed with Proofsystem (where thm is  *)
(* abstract) code outside this struct cannot construct a thm from an         *)
(* arbitrary formula.                                                        *)
module Proven : Proofsystem =
  struct
    type thm = fol formula
    (* Propositional axiom schemas: each simply builds the schema instance. *)
    let axiom_addimp p q = Imp(p,Imp(q,p))
    let axiom_distribimp p q r =
      Imp(Imp(p,Imp(q,r)),Imp(Imp(p,q),Imp(p,r)))
    let axiom_doubleneg p = Imp(Imp(Imp(p,False),False),p)
    let axiom_allimp x p q =
      Imp(Forall(x,Imp(p,q)),Imp(Forall(x,p),Forall(x,q)))
    (* Side condition checked at run time: x must not be free in p. *)
    let axiom_impall x p =
      if not (free_in (Var x) p) then Imp(p,Forall(x,p))
      else failwith "axiom_impall: variable free in formula"
    (* Side condition checked at run time: x must not occur in t. *)
    let axiom_existseq x t =
      if not (occurs_in (Var x) t) then Exists(x,mk_eq (Var x) t)
      else failwith "axiom_existseq: variable free in term"
    let axiom_eqrefl t = mk_eq t t
    (* Congruence schemas: one hypothesis s = t is prefixed for each pair   *)
    (* drawn from lefts and rights (itlist2 is defined elsewhere; it is     *)
    (* presumably a right fold over two lists of equal length).             *)
    let axiom_funcong f lefts rights =
       itlist2 (fun s t p -> Imp(mk_eq s t,p)) lefts rights
               (mk_eq (Fn(f,lefts)) (Fn(f,rights)))
    let axiom_predcong p lefts rights =
       itlist2 (fun s t p -> Imp(mk_eq s t,p)) lefts rights
               (Imp(Atom(R(p,lefts)),Atom(R(p,rights))))
    (* Schemas relating <=> to a pair of implications. *)
    let axiom_iffimp1 p q = Imp(Iff(p,q),Imp(p,q))
    let axiom_iffimp2 p q = Imp(Iff(p,q),Imp(q,p))
    let axiom_impiff p q = Imp(Imp(p,q),Imp(Imp(q,p),Iff(p,q)))
    (* Definitional schemas expressing the remaining connectives and the    *)
    (* existential quantifier via ==>, false and forall.                    *)
    let axiom_true = Iff(True,Imp(False,False))
    let axiom_not p = Iff(Not p,Imp(p,False))
    let axiom_or p q = Iff(Or(p,q),Not(And(Not(p),Not(q))))
    let axiom_and p q = Iff(And(p,q),Imp(Imp(p,Imp(q,False)),False))
    let axiom_exists x p = Iff(Exists(x,p),Not(Forall(x,Not p)))
    (* Modus ponens: the first theorem must be an implication whose         *)
    (* antecedent is structurally equal to the second theorem.              *)
    let modusponens pq p =
      match pq with
        Imp(p',q) ->
             if p = p' then q else failwith "modusponens: no matchup"
      | _ -> failwith "modusponens: not an implication"
    (* Generalization: no side condition is imposed on x. *)
    let gen x p = Forall(x,p)
    (* Recover the formula; this is the identity on the representation. *)
    let concl c = c
  end;;

(* ------------------------------------------------------------------------- *)
(* A printer for theorems.                                                   *)
(* ------------------------------------------------------------------------- *)

open Proven;;

(* Pretty-print a theorem as a turnstile followed by its conclusion, using   *)
(* nested Format boxes so the formula can break independently.               *)
let print_thm th =
  let body = concl th in
  open_box 0;
  print_string "|-";
  print_space();
  open_box 0;
  print_formula print_atom 0 body;
  close_box();
  close_box();;

(*
#install_printer print_thm;;
*)
(* ========================================================================= *)
(* Propositional reasoning by derived rules atop the LCF core.               *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* Split an equivalence into its two sides; fails on any other formula. *)
let dest_iff = function
  | Iff(p,q) -> (p,q)
  | _ -> failwith "dest_iff: not an equivalence";;

(* Split a conjunction into its two conjuncts; fails on any other formula. *)
let dest_and = function
  | And(p,q) -> (p,q)
  | _ -> failwith "dest_and: not a conjunction";;

(* Right-hand and left-hand side of an implication.  Both fail exactly when  *)
(* dest_imp fails.  They are written with an explicit argument because a     *)
(* point-free definition is an application, which the value restriction      *)
(* refuses to generalize and so leaves weakly polymorphic.                   *)
let consequent fm = snd (dest_imp fm);;
let antecedent fm = fst (dest_imp fm);;

(* ------------------------------------------------------------------------- *)
(* If |- q then |- p ==> q                                                   *)
(* ------------------------------------------------------------------------- *)

(* Weaken the theorem th by prefixing the hypothesis p, via axiom_addimp. *)
let add_assum p th =
  let q = concl th in
  modusponens (axiom_addimp q p) th;;

(* ------------------------------------------------------------------------- *)
(* If |- q ==> r then |- (p ==> q) ==> (p ==> r)                             *)
(* ------------------------------------------------------------------------- *)

(* Put the extra hypothesis p in front of both sides of the implication th.  *)
(* Fails if the conclusion of th is not an implication.                      *)
let imp_add_assum p th =
  let q,r = dest_imp (concl th) in
  let weakened = add_assum p th in
  modusponens (axiom_distribimp p q r) weakened;;

(* ------------------------------------------------------------------------- *)
(* If |- p1 ==> .. ==> pn ==> q & |- q ==> r then |- p1 ==> .. ==> pn ==> r  *)
(* ------------------------------------------------------------------------- *)

(* Generalized transitivity.  The antecedent of th2 is located somewhere     *)
(* along the chain of implications concluding th1; every hypothesis that     *)
(* precedes it is added in front of th2 before the final modus ponens.       *)
let imp_trans =
  (* Hypotheses peeled off fm, outermost first, until goal is reached;      *)
  (* fails via dest_imp if goal is never met.                               *)
  let rec prefix_upto goal fm =
    if fm = goal then []
    else
      let hyp,rest = dest_imp fm in
      hyp :: prefix_upto goal rest in
  fun th1 th2 ->
    let middle = antecedent (concl th2) in
    let hyps = prefix_upto middle (concl th1) in
    let lifted = itlist imp_add_assum hyps th2 in
    modusponens lifted th1;;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> r then |- p ==> q ==> r                                       *)
(* ------------------------------------------------------------------------- *)

(* Insert the extra hypothesis q just before the final consequent of th.     *)
(* Only the consequent r is needed, so the antecedent is no longer bound to  *)
(* an unused name.  Fails if the conclusion of th is not an implication.     *)
let imp_insert q th =
  let r = consequent (concl th) in
  imp_trans th (axiom_addimp r q);;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q ==> r then |- q ==> p ==> r                                 *)
(* ------------------------------------------------------------------------- *)

(* Exchange the first two hypotheses of a theorem of shape p ==> q ==> r.    *)
(* Fails if the conclusion does not have two nested implications.            *)
let imp_swap th =
  let p,rest = dest_imp (concl th) in
  let q,r = dest_imp rest in
  let distributed = modusponens (axiom_distribimp p q r) th in
  imp_trans (axiom_addimp q p) distributed;;

(* ------------------------------------------------------------------------- *)
(* |- p ==> p                                                                *)
(* ------------------------------------------------------------------------- *)

(* Reflexivity of implication, derived from the two propositional axioms     *)
(* axiom_distribimp and axiom_addimp.                                        *)
let imp_refl p =
  let pp = Imp(p,p) in
  let step = modusponens (axiom_distribimp p pp p) (axiom_addimp p pp) in
  modusponens step (axiom_addimp p p);;

(* ------------------------------------------------------------------------- *)
(* |- (q ==> r) ==> (p ==> q) ==> (p ==> r)                                  *)
(* ------------------------------------------------------------------------- *)

(* Transitivity of implication stated as a theorem rather than a rule. *)
let imp_trans_th p q r =
  let weaken = axiom_addimp (Imp(q,r)) p
  and distrib = axiom_distribimp p q r in
  imp_trans weaken distrib;;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q then |- (q ==> r) ==> (p ==> r)                             *)
(* ------------------------------------------------------------------------- *)

(* Append the common conclusion r to both sides of th, reversing their       *)
(* order.  Fails if the conclusion of th is not an implication.              *)
let imp_add_concl r th =
  let p,q = dest_imp (concl th) in
  let swapped = imp_swap (imp_trans_th p q r) in
  modusponens swapped th;;

(* ------------------------------------------------------------------------- *)
(* |- (p ==> q ==> r) ==> (q ==> p ==> r)                                    *)
(* ------------------------------------------------------------------------- *)

(* The hypothesis-swapping rule stated as a theorem. *)
let imp_swap_th p q r =
  let pr = Imp(p,r) in
  let tail = imp_add_concl pr (axiom_addimp q p) in
  imp_trans (axiom_distribimp p q r) tail;;

(* ------------------------------------------------------------------------- *)
(* Mappings between |- p <=> q, |- p ==> q and |- q ==> p                    *)
(* ------------------------------------------------------------------------- *)

(* Left-to-right implication extracted from an equivalence theorem.          *)
(* Fails if the conclusion of th is not an equivalence.                      *)
let iff_imp1 th =
  let l,r = dest_iff (concl th) in
  let forward = axiom_iffimp1 l r in
  modusponens forward th;;

(* Right-to-left implication extracted from an equivalence theorem.          *)
(* Fails if the conclusion of th is not an equivalence.                      *)
let iff_imp2 th =
  let l,r = dest_iff (concl th) in
  let backward = axiom_iffimp2 l r in
  modusponens backward th;;

(* Combine the implication th1 with its converse th2 into an equivalence.    *)
(* The two sides are read off th1; a mismatching th2 makes modusponens fail. *)
let imp_antisym th1 th2 =
  let p,q = dest_imp (concl th1) in
  let needs_converse = modusponens (axiom_impiff p q) th1 in
  modusponens needs_converse th2;;

(* ------------------------------------------------------------------------- *)
(* |- p ==> q ==> p /\ q                                                     *)
(* ------------------------------------------------------------------------- *)

(* Conjunction introduction as a theorem, obtained by unfolding the          *)
(* definition of conjunction given by axiom_and.                             *)
let and_pair p q =
  let refute = Imp(p,Imp(q,False)) in
  let intro = iff_imp2 (axiom_and p q) in
  let swapped = imp_swap_th refute q False in
  let lifted = imp_add_assum p (imp_trans swapped intro) in
  modusponens lifted (imp_swap (imp_refl refute));;

(* ------------------------------------------------------------------------- *)
(* |- p /\ q ==> p                                                           *)
(* ------------------------------------------------------------------------- *)

(* Projection of the left conjunct, via the definition of conjunction and    *)
(* double-negation elimination on p.                                         *)
let and_left p q =
  let unfold = iff_imp1 (axiom_and p q) in
  let weaken = imp_add_assum p (axiom_addimp False q) in
  let to_p = imp_trans (imp_add_concl False weaken) (axiom_doubleneg p) in
  imp_trans unfold to_p;;

(* ------------------------------------------------------------------------- *)
(* |- p /\ q ==> q                                                           *)
(* ------------------------------------------------------------------------- *)

(* Projection of the right conjunct, via the definition of conjunction and   *)
(* double-negation elimination on q.                                         *)
let and_right p q =
  let unfold = iff_imp1 (axiom_and p q) in
  let weaken = axiom_addimp (Imp(q,False)) p in
  let to_q = imp_trans (imp_add_concl False weaken) (axiom_doubleneg q) in
  imp_trans unfold to_q;;

(* ------------------------------------------------------------------------- *)
(* |- p1 /\ ... /\ pn ==> pi for each 1 <= i <= n (input term right assoc)   *)
(* ------------------------------------------------------------------------- *)

(* For a right-nested conjunction fm, return one theorem per conjunct, each  *)
(* of shape fm ==> conjunct, in left-to-right order.  A formula that is not  *)
(* a conjunction yields the single theorem fm ==> fm.                        *)
(* The conjunction test is a direct pattern match, so a Failure raised by    *)
(* one of the derived rules is no longer caught and mistaken for the         *)
(* not-a-conjunction case.                                                   *)
let rec conjths fm =
  match fm with
    And(p,q) ->
      (and_left p q)::map (imp_trans (and_right p q)) (conjths q)
  | _ -> [imp_refl fm];;

(* ------------------------------------------------------------------------- *)
(* |- false ==> p                                                            *)
(* ------------------------------------------------------------------------- *)

(* Anything follows from false: weaken false, then eliminate the double      *)
(* negation on p.                                                            *)
let ex_falso p =
  let vacuous = axiom_addimp False (Imp(p,False)) in
  imp_trans vacuous (axiom_doubleneg p);;

(* ------------------------------------------------------------------------- *)
(* |- (q ==> false) ==> p ==> (p ==> q) ==> false                            *)
(* ------------------------------------------------------------------------- *)

(* An implication with a true antecedent and a refuted consequent is itself  *)
(* refuted; built from imp_trans_th followed by a swap of hypotheses.        *)
let imp_truefalse p q =
  let compose = imp_trans_th p q False in
  let reorder = imp_swap_th (Imp(p,q)) p False in
  imp_trans compose reorder;;

(* ------------------------------------------------------------------------- *)
(* |- true                                                                   *)
(* ------------------------------------------------------------------------- *)

(* The theorem |- true, obtained by unfolding axiom_true and discharging     *)
(* false ==> false with imp_refl.                                            *)
let truth =
  let unfold = iff_imp2 axiom_true in
  modusponens unfold (imp_refl False);;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> p ==> q then |- p ==> q                                       *)
(* ------------------------------------------------------------------------- *)

(* Collapse a repeated leading hypothesis.  Fails if th does not have two    *)
(* nested implications, or (inside modusponens) if the two hypotheses are    *)
(* not the same formula.                                                     *)
let imp_unduplicate th =
  let p,inner = dest_imp (concl th) in
  let q = consequent inner in
  let collapsed = modusponens (axiom_distribimp p p q) th in
  modusponens collapsed (imp_refl p);;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> qi for 1<=i<=n and |- q1 ==> ... ==> qn ==> r then |- p ==> r *)
(* ------------------------------------------------------------------------- *)

(* Discharge the hypotheses of th one by one using the theorems in ths, all  *)
(* of which share the same antecedent.  ths must be non-empty: hd and tl     *)
(* are applied to it unconditionally.                                        *)
let imp_trans_chain ths th =
  let seed = imp_trans (hd ths) th in
  let step a b = imp_unduplicate (imp_trans a (imp_swap b)) in
  itlist step (rev (tl ths)) seed;;

(* ------------------------------------------------------------------------- *)
(* |- (p <=> q) <=> (p ==> q) /\ (q ==> p)                                   *)
(* ------------------------------------------------------------------------- *)

(* Equivalence unfolded into a conjunction of two implications; each         *)
(* direction is proved separately and the two are joined by imp_antisym.     *)
let iff_expand p q =
  let fwd = Imp(p,q) in
  let bwd = Imp(q,p) in
  let left_to_right =
    imp_trans_chain [axiom_iffimp1 p q; axiom_iffimp2 p q]
                    (and_pair fwd bwd) in
  let right_to_left =
    imp_trans_chain [and_left fwd bwd; and_right fwd bwd]
                    (axiom_impiff p q) in
  imp_antisym left_to_right right_to_left;;

(* ------------------------------------------------------------------------- *)
(* Recursively evaluate expression.                                          *)
(* ------------------------------------------------------------------------- *)

(* peval cnj ths (pat,fm) evaluates fm under the assumption cnj.  The result *)
(* is a theorem whose conclusion is either cnj ==> fm or                     *)
(* cnj ==> (fm ==> false).  The pattern pat is traversed in parallel with fm *)
(* and decides how deep to recurse: wherever pat and fm stop having the same *)
(* top-level connective, fm is treated as an atom and looked up in ths.      *)
(* ths is expected to hold theorems of shape cnj ==> a or                    *)
(* cnj ==> (a ==> false); lcfptaut passes conjths cnj.                       *)
let rec peval cnj ths fmp =
  match fmp with
    False,False -> add_assum cnj (imp_refl False)
  | True,True -> add_assum cnj truth
  | Imp(p0,q0),Imp(p,q) ->
        let pth = peval cnj ths (p0,p)
        and qth = peval cnj ths (q0,q) in
        (* q holds: the implication holds whatever p is. *)
        if consequent(concl qth) = q then
          imp_insert p qth
        (* p is refuted: the implication holds by ex falso. *)
        else if consequent(concl pth) = Imp(p,False) then
          imp_trans pth (ex_falso q)
        (* p holds and q is refuted: the implication is refuted. *)
        else
          let th1 = imp_trans qth (imp_truefalse p q) in
          let th2 = axiom_distribimp cnj p (Imp(Imp(p,q),False)) in
          modusponens (modusponens th2 th1) pth
  (* The remaining connectives are unfolded through their defining          *)
  (* equivalences and evaluated on the unfolded form by repeval.            *)
  | Not(p0),Not(p) ->
        repeval cnj ths (axiom_not p) (axiom_not p0)
  | Or(p0,q0),Or(p,q) ->
        repeval cnj ths (axiom_or p q) (axiom_or p0 q0)
  | And(p0,q0),And(p,q) ->
        repeval cnj ths (axiom_and p q) (axiom_and p0 q0)
  | Iff(p0,q0),Iff(p,q) ->
        repeval cnj ths (iff_expand p q) (iff_expand p0 q0)
  (* Atom case: look for a theorem asserting fm, then one refuting it. *)
  | _,fm ->
        try find (fun th -> consequent(concl th) = fm) ths
        with Failure _ -> try
            find (fun th -> consequent(concl th) = Imp(fm,False)) ths
        with Failure _ -> failwith "no assignment for atom"

(* repeval cnj ths th th0: th is an equivalence old <=> nw for the actual    *)
(* formula and th0 the corresponding equivalence for the pattern.  The       *)
(* right-hand side nw is evaluated (against the right-hand side of th0 as    *)
(* pattern) and the verdict is carried back across th to old.                *)
and repeval cnj ths th th0 =
  let (old,nw) = dest_iff(concl th)
  and nw0 = snd(dest_iff(concl th0)) in
  let eth = peval cnj ths (nw0,nw) in
  (* nw holds: go from nw to old; otherwise contrapose old ==> nw. *)
  if consequent(concl eth) = nw then imp_trans eth (iff_imp2 th)
  else imp_trans eth (imp_add_concl False (iff_imp1 th));;

(* ------------------------------------------------------------------------- *)
(* If |- p /\ q ==> r then |- p ==> q ==> r                                  *)
(* ------------------------------------------------------------------------- *)

(* Curry a theorem whose antecedent is a conjunction into two nested         *)
(* hypotheses.  Fails if the antecedent of th is not a conjunction.          *)
let shunt th =
  let conj = antecedent (concl th) in
  let p,q = dest_and conj in
  let curried = itlist imp_add_assum [p; q] th in
  modusponens curried (and_pair p q);;

(* ------------------------------------------------------------------------- *)
(* If |- (p ==> false) ==> p then |- p                                       *)
(* ------------------------------------------------------------------------- *)

(* Proof by contradiction: p is read off as the consequent of th, and th is  *)
(* expected to derive p from its own negation.                               *)
let contrad th =
  let p = consequent (concl th) in
  let not_p = Imp(p,False) in
  let self = imp_refl not_p in
  let absurd = modusponens (axiom_distribimp not_p p False) self in
  modusponens (axiom_doubleneg p) (modusponens absurd th);;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q and |- (p ==> false) ==> q then |- q                        *)
(* ------------------------------------------------------------------------- *)

(* Case analysis on p: th1 covers the case where p holds, th2 the case       *)
(* where it is refuted; the two are merged through contrad.                  *)
let bool_cases th1 th2 =
  let contraposed = imp_add_concl False th1 in
  contrad (imp_trans contraposed th2);;

(* ------------------------------------------------------------------------- *)
(* Collect the atoms (including quantified subformulas).                     *)
(* ------------------------------------------------------------------------- *)

(* Walk the pattern and the formula in parallel and collect, without         *)
(* duplicates (via union), every subformula of the formula at which the      *)
(* pattern stops matching its top-level connective.                          *)
let rec patoms = function
  | (False,False) | (True,True) -> []
  | Not(p0),Not(p) -> patoms (p0,p)
  | And(p0,q0),And(p,q)
  | Or(p0,q0),Or(p,q)
  | Imp(p0,q0),Imp(p,q)
  | Iff(p0,q0),Iff(p,q) -> union (patoms (p0,p)) (patoms (q0,q))
  | _,fm -> [fm];;

(* ------------------------------------------------------------------------- *)
(* Prove tautology using pattern term to identify atoms.                     *)
(*                                                                           *)
(* Essentially implements Kalmar's completeness proof (in Mendelson's book). *)
(* ------------------------------------------------------------------------- *)

(* lcfptaut pat fm proves the tautology fm, using pat to decide which        *)
(* subformulas count as atoms.  Each atom is split into a true and a false   *)
(* case; the assumptions accumulate as a conjunction, and at the leaves      *)
(* peval decides fm.  Fails if the final theorem is not exactly fm.          *)
let lcfptaut =
  let rec case_split atoms assumed fmp =
    match atoms with
    | [] -> peval assumed (conjths assumed) fmp
    | atom :: rest ->
        let when_false =
          shunt (case_split rest (And(Imp(atom,False),assumed)) fmp) in
        let when_true =
          shunt (case_split rest (And(atom,assumed)) fmp) in
        bool_cases when_true when_false in
  fun pat fm ->
    let fmp = (pat,fm) in
    let under_true = case_split (patoms fmp) True fmp in
    let th = modusponens under_true truth in
    if concl th = fm then th else failwith "lcftaut";;

(* ------------------------------------------------------------------------- *)
(* Simple case using formula itself as a pattern.                            *)
(* ------------------------------------------------------------------------- *)

(* Tautology prover that uses the formula as its own pattern, so every       *)
(* propositional connective is traversed.                                    *)
let lcftaut fm =
  let pattern = fm in
  lcfptaut pattern fm;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
lcftaut <<(p ==> q) \/ (q ==> p)>>;;

lcftaut <<p /\ q <=> ((p <=> q) <=> p \/ q)>>;;

lcftaut <<((p ==> q) ==> p) ==> p>>;;

(* ------------------------------------------------------------------------- *)
(* Indication of why we sometimes need lcfptaut                              *)
(* ------------------------------------------------------------------------- *)

let fm = let p = <<a /\ b /\ c /\ d /\ e /\ f /\ g>> in Imp(p,p);;
lcftaut fm;;
lcfptaut <<p ==> p>> fm;;

(* ------------------------------------------------------------------------- *)
(* More examples/tests.                                                      *)
(* ------------------------------------------------------------------------- *)

time lcftaut <<true>>;;

time lcftaut <<false ==> (false ==> false)>>;;

time lcftaut <<p ==> p>>;;

time lcftaut <<(p ==> q) \/ (q ==> p)>>;;

time lcftaut <<(p ==> ~p) \/ p>>;;

time lcftaut <<(p <=> q) <=> (q <=> p)>>;;

time lcftaut <<p \/ (q <=> r) <=> (p \/ q <=> p \/ r)>>;;

time lcftaut <<p /\ q <=> ((p <=> q) <=> p \/ q)>>;;

time lcftaut <<p \/ (q <=> r) <=> (p \/ q <=> p \/ r)>>;;

time lcftaut <<(p ==> q) ==> (q ==> r) ==> p ==> r>>;;

time lcftaut <<((p ==> q) ==> p) ==> p>>;;

time lcftaut <<((p ==> q) ==> q) ==> (p ==> false) ==> q>>;;

time lcftaut <<((p ==> q) ==> false) ==> q ==> p>>;;

time lcftaut <<(p ==> p ==> q) ==> p ==> q>>;;

time lcftaut <<((p ==> q) ==> q) ==> (q ==> false) ==> p>>;;

time lcftaut
 <<(p ==> p) ==> (p ==> q ==> q ==> r ==> s ==> t ==> p)>>;;
*)
(* ========================================================================= *)
(* First order reasoning by derived rules atop the LCF core.                 *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ---------------------------------------------------------------------- *)
(* Symmetry and transitivity of equality.                                 *)
(* ---------------------------------------------------------------------- *)

(* |- s = t ==> t = s, from the congruence axiom for the equality predicate  *)
(* by discharging its two reflexive hypotheses s = s.                        *)
let eq_sym s t =
  let refl = axiom_eqrefl s in
  let discharge th = modusponens (imp_swap th) refl in
  discharge (discharge (axiom_predcong "=" [s; s] [t; s]));;

(* |- s = t ==> t = u ==> s = u, from the congruence axiom for equality      *)
(* combined with symmetry.                                                   *)
let eq_trans s t u =
  let cong = axiom_predcong "=" [t; u] [s; u] in
  let step = modusponens (imp_swap cong) (axiom_eqrefl u) in
  imp_trans (eq_sym s t) step;;

(* ------------------------------------------------------------------------- *)
(* Congruences.                                                              *)
(* ------------------------------------------------------------------------- *)

(* congruence s t tm proves s = t ==> tm = tm', where tm' is tm with its     *)
(* outermost occurrences of s replaced by t.                                 *)
(* The compound case is an explicit match rather than a refutable let        *)
(* pattern.  A variable cannot reach it, because occurs_in s (Var y) holds   *)
(* only when s is that variable, which the first test already handles; the   *)
(* extra arm merely makes the match exhaustive.                              *)
let rec congruence s t tm =
  if s = tm then imp_refl(mk_eq s t)
  else if not (occurs_in s tm)
  then add_assum (mk_eq s t) (axiom_eqrefl tm) else
  match tm with
    Fn(f,args) ->
      let ths = map (congruence s t) args in
      let tms = map (consequent ** concl) ths in
      imp_trans_chain ths (axiom_funcong f (map lhs tms) (map rhs tms))
  | Var _ -> failwith "congruence: variable has no proper subterm";;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(*
congruence <<|s|>>  <<|t|>>
           <<|f(s,g(s,t,s),u,h(h(s)))|>> ;;
*)

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q then |- ~q ==> ~p                                           *)
(* ------------------------------------------------------------------------- *)

(* Contraposition: unfold ~q, contrapose th against false, fold back to ~p.  *)
(* Fails if the conclusion of th is not an implication.                      *)
let contrapos th =
  let p,q = dest_imp (concl th) in
  let unfold_q = iff_imp1 (axiom_not q) in
  let fold_p = iff_imp2 (axiom_not p) in
  imp_trans (imp_trans unfold_q (imp_add_concl False th)) fold_p;;

(* ------------------------------------------------------------------------- *)
(* |- ~ ~p ==> p                                                             *)
(* ------------------------------------------------------------------------- *)

(* Double-negation elimination phrased with the ~ connective, reduced to     *)
(* axiom_doubleneg by unfolding both negations.                              *)
let neg_neg p =
  let outer = iff_imp1 (axiom_not (Not p)) in
  let inner = imp_add_concl False (iff_imp2 (axiom_not p)) in
  imp_trans (imp_trans outer inner) (axiom_doubleneg p);;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q then |- (forall x. p) ==> (forall x. q)                     *)
(* ------------------------------------------------------------------------- *)

(* Universally quantify both sides of the implication th over x.             *)
(* Fails if the conclusion of th is not an implication.                      *)
let genimp x th =
  let p,q = dest_imp (concl th) in
  let generalized = gen x th in
  modusponens (axiom_allimp x p q) generalized;;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q then |- (exists x. p) ==> (exists x. q)                     *)
(* ------------------------------------------------------------------------- *)

(* Existentially quantify both sides of the implication th over x, by        *)
(* unfolding exists on each side and contraposing twice around genimp.       *)
let eximp x th =
  let p,q = dest_imp (concl th) in
  let middle = contrapos (genimp x (contrapos th)) in
  let unfold_p = iff_imp1 (axiom_exists x p)
  and fold_q = iff_imp2 (axiom_exists x q) in
  end_itlist imp_trans [unfold_p; middle; fold_q];;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q[x] then |- p ==> forall x. q[x]                             *)
(* ------------------------------------------------------------------------- *)

(* Generalize the consequent of th over x.  axiom_impall enforces the side   *)
(* condition and fails if x is free in the antecedent.                       *)
let gen_right x th =
  let p,q = dest_imp (concl th) in
  let distributed = modusponens (axiom_allimp x p q) (gen x th) in
  imp_trans (axiom_impall x p) distributed;;

(* ------------------------------------------------------------------------- *)
(* If |- p(x) ==> q then |- (exists x. p(x)) ==> q                           *)
(* ------------------------------------------------------------------------- *)

(* Existentially quantify the antecedent of th over x.  Works through        *)
(* gen_right on the contrapositive, so it fails when x is free in the        *)
(* negated consequent.                                                       *)
let exists_left x th =
  let p,q = dest_imp (concl th) in
  let contra = contrapos (gen_right x (contrapos th)) in
  let to_notnot = imp_trans (iff_imp1 (axiom_exists x p)) contra in
  imp_trans to_notnot (neg_neg q);;

(* ------------------------------------------------------------------------- *)
(* If |- exists x. (p(x) ==> q) then |- (forall x. p(x)) ==> q               *)
(* ------------------------------------------------------------------------- *)

(* From a theorem asserting that an implication holds for some x, derive     *)
(* that the universally quantified antecedent implies the consequent.  The   *)
(* inner gen_right call fails if x is free in the consequent.  Fails with    *)
(* its own message if th is not an existentially quantified implication.     *)
let exists_imp th =
  match concl th with
  | Exists(x,(Imp(p,q) as body)) ->
      let all_not = Forall(x,Not body) in
      let not_q = Imp(q,False) in
      let fold_not = iff_imp2 (axiom_not body) in
      let refute = gen_right x (imp_trans (imp_truefalse p q) fold_not) in
      let pushed = imp_trans refute (axiom_allimp x p (Not body)) in
      let not_all = modusponens (iff_imp1 (axiom_exists x body)) th in
      let absurd = imp_trans not_all (iff_imp1 (axiom_not all_not)) in
      let closed = modusponens (imp_trans_th not_q all_not False) absurd in
      imp_trans (imp_trans (imp_swap pushed) closed) (axiom_doubleneg q)
  | _ -> failwith "exists_imp: wrong sort of theorem";;

(* ------------------------------------------------------------------------- *)
(* Equivalence properties of logical equivalence.                            *)
(* ------------------------------------------------------------------------- *)

(* |- p <=> p, from p ==> p used in both directions. *)
let iff_refl p =
  let self = imp_refl p in
  imp_antisym self self;;

(* Symmetry of equivalence: extract both implications from th and rejoin     *)
(* them in the opposite order.  Fails if th is not an equivalence.           *)
let iff_sym th = imp_antisym (iff_imp2 th) (iff_imp1 th);;

(* Transitivity of equivalence as a theorem.  The propositional pattern is   *)
(* fixed once; each call proves the instance for the given p, q and r.       *)
let iff_trans_th =
  let prove_instance =
    lcfptaut <<(p <=> q) ==> (q <=> r) ==> (p <=> r)>> in
  fun p q r ->
    let pq = Iff(p,q) and qr = Iff(q,r) and pr = Iff(p,r) in
    prove_instance (Imp(pq,Imp(qr,pr)));;

(* Transitivity of equivalence as a rule: from p <=> q and q <=> r derive    *)
(* p <=> r.  The middle formula is taken from th2 only; the right-hand side  *)
(* of th1 is no longer bound to a name that was immediately shadowed.  If    *)
(* the two middle formulas differ, the first modusponens fails.              *)
let iff_trans th1 th2 =
  let p,_ = dest_iff(concl th1) in
  let q,r = dest_iff(concl th2) in
  modusponens (modusponens (iff_trans_th p q r) th1) th2;;

(* ------------------------------------------------------------------------- *)
(* Congruence properties of the propositional connectives.                   *)
(* ------------------------------------------------------------------------- *)

(* Congruence of negation with respect to equivalence, proved per instance   *)
(* against a fixed propositional pattern.                                    *)
let cong_not =
  let prove_instance = lcfptaut <<(p <=> p') ==> (~p <=> ~p')>> in
  fun p p' ->
    let hyp = Iff(p,p') and goal = Iff(Not p,Not p') in
    prove_instance (Imp(hyp,goal));;

(* Congruence of a binary connective with respect to equivalence.  c builds  *)
(* the connective from a pair; the same shape is instantiated once with      *)
(* atomic placeholders (the pattern) and once with the actual formulas.      *)
let cong_bin =
  let atom_p = <<p>> and atom_p' = <<p'>> in
  let atom_q = <<q>> and atom_q' = <<q'>> in
  let hyp_p = <<p <=> p'>> and hyp_q = <<q <=> q'>> in
  fun c p p' q q' ->
    let shape hp hq l r l' r' = Imp(hp,Imp(hq,Iff(c(l,r),c(l',r')))) in
    let pat = shape hyp_p hyp_q atom_p atom_q atom_p' atom_q' in
    let goal = shape (Iff(p,p')) (Iff(q,q')) p q p' q' in
    lcfptaut pat goal;;

(* ------------------------------------------------------------------------- *)
(* |- (forall x. P(x) <=> Q(x)) ==> ((forall x. P(x)) <=> (forall x. Q(x)))  *)
(* ------------------------------------------------------------------------- *)

(* A universally quantified equivalence yields an equivalence between the    *)
(* two universal formulas; the two directions are built symmetrically.       *)
let forall_iff x p q =
  let direction ax a b =
    imp_trans (genimp x (ax p q)) (axiom_allimp x a b) in
  let forward = direction axiom_iffimp1 p q
  and backward = direction axiom_iffimp2 q p in
  imp_trans_chain [forward; backward]
                  (axiom_impiff (Forall(x,p)) (Forall(x,q)));;

(* ------------------------------------------------------------------------- *)
(* |- (forall x. P(x) <=> Q(x)) ==> ((exists x. P(x)) <=> (exists x. Q(x)))  *)
(* ------------------------------------------------------------------------- *)

(* A universally quantified equivalence yields an equivalence between the    *)
(* two existential formulas.  The proof goes through the definition of       *)
(* exists as a negated universal, using cong_not and forall_iff.             *)
let exists_iff x p q =
  let ex_p = Exists(x,p) and ex_q = Exists(x,q) in
  let all_np = Forall(x,Not p) and all_nq = Forall(x,Not q) in
  let negated = genimp x (cong_not p q) in
  let under_forall = imp_trans negated (forall_iff x (Not p) (Not q)) in
  let under_not = imp_trans under_forall (cong_not all_np all_nq) in
  let unfold_p =
    modusponens (iff_trans_th ex_p (Not all_np) (Not all_nq))
                (axiom_exists x p) in
  let left_half = imp_trans under_not unfold_p in
  let fold_q =
    modusponens (imp_swap (iff_trans_th ex_p (Not all_nq) ex_q))
                (iff_sym (axiom_exists x q)) in
  imp_trans left_half fold_q;;

(* ------------------------------------------------------------------------- *)
(* Substitution...                                                           *)
(* ------------------------------------------------------------------------- *)

(* isubst s t fm proves a theorem of shape s = t ==> (fm <=> fm'), where     *)
(* fm' is obtained from fm by replacing s with t (term-level replacement is  *)
(* done by congruence).  When s is not free in fm, fm' is fm itself.         *)
(* The four functions below are mutually recursive: substitution under a     *)
(* binder may first need the bound variable renamed (alpha), renaming is     *)
(* built on specialization (ispec), and specialization is built on isubst.   *)
let rec isubst s t fm =
  if not (free_in s fm) then add_assum (mk_eq s t) (iff_refl fm) else
  match fm with
    Atom(R(p,args)) ->
      if args = [] then add_assum (mk_eq s t) (iff_refl fm) else
      (* Rewrite every argument, then use predicate congruence in both      *)
      (* directions (the reverse one via eq_sym) to get an equivalence.     *)
      let ths = map (congruence s t) args in
      let lts,rts = unzip (map (dest_eq ** consequent ** concl) ths) in
      let ths' = map2 imp_trans ths (map2 eq_sym lts rts) in
      let th = imp_trans_chain ths (axiom_predcong p lts rts)
      and th' = imp_trans_chain ths' (axiom_predcong p rts lts) in
      let fm' = consequent(consequent(concl th)) in
      imp_trans_chain [th; th'] (axiom_impiff fm fm')
  | Not(p) ->
      let th = isubst s t p in
      let p' = snd(dest_iff(consequent(concl th))) in
      imp_trans th (cong_not p p')
  | And(p,q) -> isubst_binary (fun (p,q) -> And(p,q)) s t p q
  | Or(p,q) -> isubst_binary (fun (p,q) -> Or(p,q)) s t p q
  | Imp(p,q) -> isubst_binary (fun (p,q) -> Imp(p,q)) s t p q
  | Iff(p,q) -> isubst_binary (fun (p,q) -> Iff(p,q)) s t p q
  | Forall(x,p) ->
      (* If the bound variable occurs in t, rename it to a fresh z first    *)
      (* (alpha in both directions gives fm <=> fm'), substitute in the     *)
      (* renamed formula, and chain the two equivalences.                   *)
      if mem x (fvt t) then
         let z = variant x (union (fvt t) (fv p)) in
         let th1 = alpha z fm in
         let fm' = consequent(concl th1) in
         let th2 = imp_antisym th1 (alpha x fm')
         and th3 = isubst s t fm' in
         let fm'' = snd(dest_iff(consequent(concl th3))) in
         imp_trans th3 (modusponens (iff_trans_th fm fm' fm'') th2)
      else
         let th = isubst s t p in
         let p' = snd(dest_iff(consequent(concl th))) in
         imp_trans (gen_right x th) (forall_iff x p p')
  | Exists(x,p) ->
      (* Unfold exists into a negated universal, substitute there, and      *)
      (* fold the result back into an existential.                          *)
      let th0 = axiom_exists x p in
      let th1 = isubst s t (snd(dest_iff(concl th0))) in
      (* Refutable pattern: relies on th1 having exactly this shape. *)
      let Imp(_,Iff(fm',(Not(Forall(y,Not(p'))) as q))) = concl th1 in
      let th2 = imp_trans th1 (modusponens (iff_trans_th fm fm' q) th0)
      and th3 = iff_sym(axiom_exists y p') in
      let r = snd(dest_iff(concl th3)) in
      imp_trans th2 (modusponens (imp_swap(iff_trans_th fm q r)) th3)
  | _ -> add_assum (mk_eq s t) (iff_refl fm)

(* Shared case for the binary connectives: substitute in both operands and   *)
(* combine with cong_bin; cons rebuilds the connective from a pair.          *)
and isubst_binary cons s t p q =
  let th_p = isubst s t p and th_q = isubst s t q in
  let p' = snd(dest_iff(consequent(concl th_p)))
  and q' = snd(dest_iff(consequent(concl th_q))) in
  let th1 = imp_trans th_p (cong_bin cons p p' q q') in
  imp_unduplicate (imp_trans th_q (imp_swap th1))

(* ------------------------------------------------------------------------- *)
(* ...specialization...                                                      *)
(* ------------------------------------------------------------------------- *)

(* ispec t fm, for a universal formula fm, proves an implication from fm to  *)
(* its body with the bound variable replaced by t.  If the bound variable    *)
(* occurs in t, fm is first renamed to use a fresh variable.                 *)
and ispec t fm =
  match fm with
    Forall(x,p) ->
      if mem x (fvt t) then
        let th1 = alpha (variant x (union (fvt t) (fv p))) fm in
        imp_trans th1 (ispec t (consequent(concl th1)))
      else
        (* From x = t ==> (p ==> p'), quantify existentially over x and     *)
        (* discharge exists x. x = t with axiom_existseq.                   *)
        let th1 = isubst (Var x) t p in
        let eq,bod = dest_imp(concl th1) in
        let p' = snd(dest_iff bod) in
        let th2 = imp_trans th1 (axiom_iffimp1 p p') in
        exists_imp(modusponens (eximp x th2) (axiom_existseq x t))
  | _ -> failwith "ispec: non-universal formula"

(* ------------------------------------------------------------------------- *)
(* ...and renaming, all mutually recursive.                                  *)
(* ------------------------------------------------------------------------- *)

(* alpha z fm proves an implication from the universal formula fm to the     *)
(* same formula with its bound variable renamed to z.  axiom_impall fails    *)
(* if z is free in fm.                                                       *)
and alpha z fm =
  let th1 = ispec (Var z) fm in
  let ant,cons = dest_imp(concl th1) in
  let th2 = modusponens (axiom_allimp z ant cons) (gen z th1) in
  imp_trans (axiom_impall z fm) th2;;

(* ------------------------------------------------------------------------- *)
(* Specialization rule.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Specialize the universally quantified theorem th to the term t. *)
let spec t th =
  let instance = ispec t (concl th) in
  modusponens instance th;;

(* ------------------------------------------------------------------------- *)
(* Tests.                                                                    *)
(* ------------------------------------------------------------------------- *)

(*
isubst <<|x + x|>> <<|2 * x|>> <<x + x = x + x + x>> ;;

ispec <<|y|>> <<forall x y z. x + y + z = z + y + x>> ;;

isubst <<|x + x|>> <<|2 * x|>> <<x + x = x ==> x = 0>> ;;

isubst <<|x + x|>>  <<|2 * x|>>
       <<(x + x = y + y) ==> (y + y + y = x + x + x)>> ;;

ispec <<|x|>> <<forall x y z. x + y + z = y + z + z>> ;;

ispec <<|x|>> <<forall x. x = x>> ;;

ispec <<|w + y + z|>> <<forall x y z. x + y + z = y + z + z>> ;;

ispec <<|x + y + z|>> <<forall x y z. x + y + z = y + z + z>> ;;

ispec <<|x + y + z|>> <<forall x y z. nothing_much>> ;;

isubst <<|x + x|>> <<|2 * x|>>
       <<(x + x = y + y) <=> (something \/ y + y + y = x + x + x)>> ;;

isubst <<|x + x|>>  <<|2 * x|>>
       <<(exists x. x = 2) <=> exists y. y + x + x = y + y + y>> ;;

isubst <<|x|>>  <<|y|>>
       <<(forall z. x = z) <=> (exists x. y < z) /\ (forall y. y < x)>> ;;
*)
(* ========================================================================= *)
(* Storing proof logs for tableaux and constructing LCF proofs from them.    *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* ------------------------------------------------------------------------- *)
(* Conversionals.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Sequence two conversions: apply conv1 to fm, apply conv2 to the         *)
(* right-hand side of the resulting equivalence, and join the two           *)
(* equivalences with iff_trans.                                             *)
let then_conv conv1 conv2 fm =
  let first = conv1 fm in
  let _,mid = dest_iff(concl first) in
  let second = conv2 mid in
  iff_trans first second;;

(* Apply conv to the immediate subformulas of fm and rebuild an            *)
(* equivalence for fm via the matching congruence rule.  Formulas with no   *)
(* handled connective at the top are mapped to a reflexive equivalence.     *)
let rec sub_conv conv fm =
  let binary mk l r = binconv conv mk (l,r) in
  let quantified rule v body = quantconv conv rule (v,body) in
  match fm with
  | Not(p) ->
        let pth = conv p in
        let p' = snd(dest_iff(concl pth)) in
        modusponens (cong_not p p') pth
  | And(l,r) -> binary (fun (s,t) -> And(s,t)) l r
  | Or(l,r) -> binary (fun (s,t) -> Or(s,t)) l r
  | Imp(l,r) -> binary (fun (s,t) -> Imp(s,t)) l r
  | Iff(l,r) -> binary (fun (s,t) -> Iff(s,t)) l r
  | Forall(v,body) -> quantified forall_iff v body
  | Exists(v,body) -> quantified exists_iff v body
  | _ -> iff_refl fm

(* Convert both sides of a binary connective (built by cons) and combine   *)
(* the two equivalences with cong_bin.                                      *)
and binconv conv cons (p,q) =
  let pth = conv p and qth = conv q in
  let rhs th = snd(dest_iff(concl th)) in
  let cong = cong_bin cons p (rhs pth) q (rhs qth) in
  modusponens (modusponens cong pth) qth

(* Convert the body of a quantified formula; crule is the congruence rule  *)
(* for the quantifier, discharged with the generalized body equivalence.    *)
and quantconv conv crule (x,p) =
  let pth = conv p in
  let _,p' = dest_iff(concl pth) in
  let cong = crule x p p' in
  modusponens cong (gen x pth);;

(* ------------------------------------------------------------------------- *)
(* Depth conversions.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Bottom-up traversal: first convert all subformulas recursively, then    *)
(* apply conv once to the rebuilt formula.                                  *)
let rec single_depth_conv conv fm =
  let descend = sub_conv (single_depth_conv conv) in
  then_conv descend conv fm;;

(* Top-down traversal: keep applying conv at the top for as long as it     *)
(* succeeds; once that attempt raises Failure, descend into the             *)
(* subformulas instead.                                                     *)
let rec top_depth_conv conv fm =
  let again = top_depth_conv conv in
  try then_conv conv again fm
  with Failure _ -> sub_conv again fm;;

(* ------------------------------------------------------------------------- *)
(* Aid to tautology-based simplification.                                    *)
(* ------------------------------------------------------------------------- *)

(* Hand lcfptaut the template equivalence pat <=> pat' together with the   *)
(* concrete equivalence fm <=> fm' that is wanted.                          *)
let tsimp fm fm' pat pat' =
  let template = Iff(pat,pat') in
  let target = Iff(fm,fm') in
  lcfptaut template target;;

(* ------------------------------------------------------------------------- *)
(* Simplification, once and at depth, by proof.                              *)
(* ------------------------------------------------------------------------- *)

(* Combine the two directions relating (forall x. p) and p:                *)
(* ispec supplies one implication, axiom_impall the other.                  *)
let forall_triv x p =
  let intro = axiom_impall x p in
  let elim = ispec (Var x) (Forall(x,p)) in
  imp_antisym elim intro;;

(* Relate (exists x. p) to p by chaining the definition of exists           *)
(* (axiom_exists) with forall_triv applied to ~p.                           *)
let exists_triv =
  let chain = lcfptaut <<(p <=> ~q) ==> (q <=> ~r) ==> (p <=> r)>> in
  fun x p ->
    let ex = Exists(x,p) and all_not = Forall(x,Not p) in
    let inst =
      chain (Imp(Iff(ex,Not all_not),Imp(Iff(all_not,Not p),Iff(ex,p)))) in
    let triv = forall_triv x (Not p) in
    let step = modusponens inst (axiom_exists x p) in
    modusponens step triv;;

(* One top-level simplification step, delivered as an equivalence theorem. *)
(* Propositional cases involving True/False go through tsimp with a         *)
(* template built over the placeholder atom a; quantifiers are dropped      *)
(* when the bound variable is not free in the body; anything else yields    *)
(* a reflexive equivalence.                                                 *)
let simplify1_conv =
  let a = Atom(R("dummy",[])) in
  fun fm ->
    let by_taut fm' pat pat' = tsimp fm fm' pat pat' in
    match fm with
    | Not False -> by_taut True (Not False) True
    | Not True -> by_taut False (Not True) False
    | And(False,_) -> by_taut False (And(False,a)) False
    | And(_,False) -> by_taut False (And(a,False)) False
    | And(True,q) -> by_taut q (And(True,a)) a
    | And(p,True) -> by_taut p (And(a,True)) a
    | Or(False,q) -> by_taut q (Or(False,a)) a
    | Or(p,False) -> by_taut p (Or(a,False)) a
    | Or(True,_) -> by_taut True (Or(True,a)) True
    | Or(_,True) -> by_taut True (Or(a,True)) True
    | Imp(False,_) -> by_taut True (Imp(False,a)) True
    | Imp(True,q) -> by_taut q (Imp(True,a)) a
    | Imp(_,True) -> by_taut True (Imp(a,True)) True
    | Imp(p,False) -> by_taut (Not p) (Imp(a,False)) (Not a)
    | Iff(True,q) -> by_taut q (Iff(True,a)) a
    | Iff(p,True) -> by_taut p (Iff(a,True)) a
    | Iff(False,q) -> by_taut (Not q) (Iff(False,a)) (Not a)
    | Iff(p,False) -> by_taut (Not p) (Iff(a,False)) (Not a)
    | Forall(x,p) ->
        if not (mem x (fv p)) then forall_triv x p else iff_refl fm
    | Exists(x,p) ->
        if not (mem x (fv p)) then exists_triv x p else iff_refl fm
    | _ -> iff_refl fm;;

(* Full simplification: simplify1_conv applied bottom-up at every node.    *)
let simplify_conv fm = single_depth_conv simplify1_conv fm;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Sample input for the commented-out simplify_conv / simplify calls       *)
(* below.  Note this is a live top-level binding named fm, and it is        *)
(* rebound by a later example in this file.                                 *)
let fm = <<forall x y. (P(x) /\ false) \/ (true ==> Q(y))>>;;

(*
simplify_conv fm;;

simplify fm;;
*)

(* ------------------------------------------------------------------------- *)
(* Negation normal form by proof.                                            *)
(* ------------------------------------------------------------------------- *)

(* From the definition of exists (axiom_exists) derive                     *)
(* ~(exists x. p) <=> forall x. ~p.                                         *)
let not_exists =
  let flip = lcfptaut <<(p <=> ~q) ==> (~p <=> q)>> in
  fun x p ->
    let ex = Exists(x,p) and all_not = Forall(x,Not p) in
    let def = axiom_exists x p in
    let inst = flip (Imp(Iff(ex,Not all_not),Iff(Not ex,all_not))) in
    modusponens inst def;;

(* Derive ~(forall x. p) <=> exists x. ~p: cancel the double negation      *)
(* under the quantifier, negate both sides, then chain with the             *)
(* definition of exists taken at ~p and flip the result.                    *)
let not_forall =
  let dneg = lcfptaut <<~(~p) <=> p>> in
  fun x p ->
    let nnp = Not(Not p) in
    let pointwise = gen x (dneg(Iff(nnp,p))) in
    let under_all = modusponens (forall_iff x nnp p) pointwise in
    let cong = cong_not (Forall(x,nnp)) (Forall(x,p)) in
    let negated = modusponens cong under_all in
    iff_sym(iff_trans (axiom_exists x (Not p)) negated);;

(* One top-level step towards negation normal form, as an equivalence      *)
(* theorem.  Eliminates ==> and <=> and pushes a negation one level         *)
(* inwards.  Raises Failure when no rule applies, which is what lets        *)
(* top_depth_conv stop rewriting at the top and descend.                    *)
let nnf1_conv =
  let a = Atom(R("dummy",[])) in
  fun fm ->
    let by_taut fm' pat pat' = tsimp fm fm' pat pat' in
    match fm with
    | Imp(p,q) ->
        by_taut (Or(Not p,q)) (Imp(a,a)) (Or(Not a,a))
    | Iff(p,q) ->
        by_taut (Or(And(p,q),And(Not p,Not q)))
                (Iff(a,a)) (Or(And(a,a),And(Not a,Not a)))
    | Not(Not p) ->
        by_taut p (Not(Not a)) a
    | Not(And(p,q)) ->
        by_taut (Or(Not p,Not q)) (Not(And(a,a))) (Or(Not a,Not a))
    | Not(Or(p,q)) ->
        by_taut (And(Not p,Not q)) (Not(Or(a,a))) (And(Not a,Not a))
    | Not(Imp(p,q)) ->
        by_taut (And(p,Not q)) (Not(Imp(a,a))) (And(a,Not a))
    | Not(Iff(p,q)) ->
        by_taut (Or(And(p,Not q),And(Not p,q)))
                (Not(Iff(a,a)))
                (Or(And(a,Not a),And(Not a,a)))
    | Not(Forall(x,p)) -> not_forall x p
    | Not(Exists(x,p)) -> not_exists x p
    | _ -> failwith "nnf1_conv: no transformation";;

(* Negation normal form by proof: nnf1_conv applied top-down, repeatedly.  *)
let nnf_conv fm = top_depth_conv nnf1_conv fm;;

(* ------------------------------------------------------------------------- *)
(* Example.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* Sample input for the commented-out nnf / nnf_conv checks below.  This   *)
(* rebinds the top-level name fm, so later code that mentions an unbound    *)
(* fm picks up this value.                                                  *)
let fm =
 <<(forall x. P(x)) ==> ((exists y. Q(y)) <=> exists z. P(z) /\ Q(z))>>;;

(*
nnf fm;;

concl (nnf_conv fm) = Iff(fm,nnf fm);;
*)

(* ------------------------------------------------------------------------- *)
(* Proof format for tableaux.                                                *)
(* ------------------------------------------------------------------------- *)

(* One step of a logged tableau proof, as emitted by tableau below.         *)
type prooflog = Literal of int   (* branch closed against the literal at   *)
                                 (* this index in the literal list         *)
              | Requeue          (* current formula moved to the literals  *)
              | Univ of term;;   (* term recorded for a universal step     *)

(* ------------------------------------------------------------------------- *)
(* Dummy ground term.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Constant used by ground and startcont in place of variables that are    *)
(* left uninstantiated.                                                     *)
let dummy = Fn("_Ground",[]);;

(* Replace every variable occurring in a term by the dummy constant.       *)
let rec ground = function
  | Fn(f,args) -> Fn(f,map ground args)
  | Var _ -> dummy;;

(* Final continuation of the tableau search.  Builds an instantiation      *)
(* that maps each "_i" for i in 0--k to dummy, overridden by the grounded   *)
(* bindings of the solved environment, and pairs it with an empty log.      *)
let startcont (env,k) =
  let defaults =
    itlist (fun i -> ("_"^string_of_int i) |-> dummy) (0--k) undefined in
  let bindings = funset(solve env) in
  let sfn = itlist (fun (x,t) -> (x|->ground t)) bindings defaults in
  sfn,[];;

(* ------------------------------------------------------------------------- *)
(* Tableau procedure with proof logging.                                     *)
(* ------------------------------------------------------------------------- *)

(* Prepend one proof step to the log carried in a search result.           *)
let logstep pstep result =
  let sfn,prf = result in
  sfn,pstep::prf;;

(* Log a universal step: record the term the final instantiation assigns   *)
(* to variable y, or Var y itself when y is unassigned.                     *)
let logforall y (sfn,prf) =
  let inst = tryapplyd sfn y (Var y) in
  sfn,Univ(inst)::prf;;

(* Tableau search with proof logging.  fms are the unexpanded formulas,    *)
(* lits the literals on the branch, n the remaining budget of universal     *)
(* instantiations.  cont is the continuation for the rest of the search;    *)
(* env is the current unification environment and k numbers the next        *)
(* fresh variable.  The result is whatever cont returns, with the steps     *)
(* taken here prepended to its log.                                         *)
let rec tableau (fms,lits,n) cont (env,k) =
  if n < 0 then failwith "no proof at this level" else
  match fms with
  | [] -> failwith "tableau: no proof"
  | And(p,q)::rest ->
      tableau (p::q::rest,lits,n) cont (env,k)
  | Or(p,q)::rest ->
      (* The second disjunct becomes the continuation of the first. *)
      let right_branch = tableau (q::rest,lits,n) cont in
      tableau (p::rest,lits,n) right_branch (env,k)
  | Forall(x,p)::rest ->
      let fresh = "_" ^ string_of_int k in
      let instance = formsubst (x := Var fresh) p in
      let requeued = rest @ [Forall(x,p)] in
      logforall fresh
       (tableau (instance::requeued,lits,n-1) cont (env,k+1))
  | fm::rest ->
      let close_with l =
        logstep (Literal(index l lits))
                (cont(unify_complements env (fm,l),k)) in
      try tryfind close_with lits
      with Failure _ ->
          logstep Requeue (tableau (rest,fm::lits,n) cont (env,k));;

(* Run the logging tableau on fms under deepen, starting from limit 0      *)
(* with no literals and an empty environment.                               *)
let tabrefute_log fms =
  let attempt n = tableau (fms,[],n) startcont (undefined,0) in
  deepen attempt 0;;

(* ------------------------------------------------------------------------- *)
(* A trivial example.                                                        *)
(* ------------------------------------------------------------------------- *)

(*
tabrefute_log
  [<<(forall x. ~P(x) \/ P(f(x))) /\ P(1) /\ ~P(f(1))>>];;
*)

(* ------------------------------------------------------------------------- *)
(* |- p ==> -p ==> false (p may be negated).                                 *)
(* ------------------------------------------------------------------------- *)

(* Unfold the definition of negation on p (or on its negation when p is    *)
(* itself negative) to obtain the contradiction rule for p.                 *)
let imp_contrad p =
  let unfold q = iff_imp1 (axiom_not q) in
  if not (negative p) then imp_swap (unfold p)
  else unfold (negate p);;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> q ==> r then |- p /\ q ==> r                                  *)
(* ------------------------------------------------------------------------- *)

(* Merge the first two antecedents p and q of th into the single           *)
(* antecedent p /\ q, using the two conjunction projections.                *)
let ante_conj th =
  let p,rest = dest_imp(concl th) in
  let q = fst(dest_imp rest) in
  let projections = [and_left p q; and_right p q] in
  imp_trans_chain projections th;;

(* ------------------------------------------------------------------------- *)
(* If |- p ==> r and |- q ==> r then |- p \/ q ==> r                         *)
(* ------------------------------------------------------------------------- *)

(* Case analysis: from |- p ==> r and |- q ==> r obtain |- p \/ q ==> r.   *)
(* Works through the contrapositives of both inputs, the definition of      *)
(* disjunction (axiom_or) and a final double-negation elimination on r.     *)
let ante_disj th1 th2 =
  let p,r = dest_imp(concl th1)
  and q,_ = dest_imp(concl th2) in
  let contras = map contrapos [th1; th2] in
  let both_neg = imp_trans_chain contras (and_pair (Not p) (Not q)) in
  let flipped = contrapos(imp_trans (iff_imp2(axiom_not r)) both_neg) in
  let from_or = imp_trans(iff_imp1(axiom_or p q)) flipped in
  let not_not_r =
    imp_trans from_or (iff_imp1 (axiom_not (Imp(r,False)))) in
  imp_trans not_not_r (axiom_doubleneg r);;

(* ------------------------------------------------------------------------- *)
(* If |- p0 ==> ... ==> pn ==> q then |- pi ==> p0 ==> ..[no pi].. pn ==> q  *)
(* ------------------------------------------------------------------------- *)

(* imp_front n th moves the antecedent at position n of th to the front.   *)
(* The local helper builds the rearranging implication for a formula by     *)
(* recursion on n; the result is then detached from th.                     *)
let imp_front =
  let rec lift n fm =
    if n = 0 then imp_refl fm else
    let p1,rest = dest_imp fm in
    let inner = imp_add_assum p1 (lift (n - 1) rest) in
    let (Imp(_,Imp(p,Imp(q,r)))) = concl inner in
    imp_trans inner (imp_swap_th p q r) in
  fun n th -> modusponens (lift n (concl th)) th;;

(* ------------------------------------------------------------------------- *)
(* If   |- (p0 ==> ... ==> pn ==> q)                                         *)
(* then |- (p1 ==> ... ==> p(i-1) ==> p0 ==> pi ==> ... ==> pn ==> q)        *)
(* ------------------------------------------------------------------------- *)

(* imp_back n th pushes the first antecedent of th back past the next n    *)
(* antecedents.  The local helper builds the rearranging implication by     *)
(* swapping the first two antecedents and recursing on the remainder.       *)
let imp_back =
  let rec sink n fm =
    if n = 0 then imp_refl fm else
    let p0,rest = dest_imp fm in
    let p1,tail = dest_imp rest in
    let swap = imp_swap_th p0 p1 tail in
    let deeper = sink (n-1) (Imp(p0,tail)) in
    imp_trans swap (imp_add_assum p1 deeper) in
  fun n th -> modusponens (sink n (concl th)) th;;

(* ------------------------------------------------------------------------- *)
(* If |- (p ==> q) ==> (q ==> r) then |- (p ==> q) ==> (p ==> r)             *)
(* ------------------------------------------------------------------------- *)

(* Given a theorem of shape (p ==> q) ==> (q ==> r), compose it with the   *)
(* transitivity theorem for p, q, r and remove the duplicated antecedent.   *)
let imp_chain_imp th =
  match concl th with
  | Imp(Imp(p,q),Imp(_,r)) ->
      let trans = imp_trans_th p q r in
      imp_unduplicate (imp_trans th trans)
  | _ -> failwith "imp_chain_imp: wrong kind of theorem";;

(* ------------------------------------------------------------------------- *)
(* Hack down Skolem instantiations list for existentials.                    *)
(* ------------------------------------------------------------------------- *)

(* Drop one entry from the head of l for each existential quantifier met   *)
(* while walking fm left to right through quantifiers, /\ and \/.           *)
let rec hack fm l =
  match fm with
  | Exists(_,body) -> hack body (tl l)
  | Forall(_,body) -> hack body l
  | And(p,q) | Or(p,q) -> hack q (hack p l)
  | _ -> l;;

(* ------------------------------------------------------------------------- *)
(* Reconstruct LCF proof from tableaux log, undoing Skolemization.           *)
(* ------------------------------------------------------------------------- *)

(* Replay a tableau proof log as an LCF proof.                              *)
(*   shyps : the Skolem hypotheses (formulas), in a fixed order             *)
(*   rfn   : replacement map handed to replacet on logged terms             *)
(*   proof : the remaining prooflog steps                                   *)
(*   fms   : unexpanded formulas, each paired with its Skolem instances     *)
(*   lits  : literals collected on the current branch                       *)
(* Returns the theorem for this subtree paired with the unconsumed part of  *)
(* the log.  The match is deliberately partial: a log that does not line up *)
(* with the formulas raises Match_failure.                                  *)
let rec reconstruct shyps rfn proof fms lits =
  match (proof,fms) with
    (* Existential: consumes no log step.  Find the Skolem hypothesis whose *)
    (* antecedent is this formula, continue with its consequent, then move  *)
    (* the hypothesis at position i around to chain the implications.       *)
    (prf,(Exists(y,p) as fm,skins)::unexp) ->
        let hfm = find (fun h -> antecedent h = fm) (hd skins) in
        let fm' = consequent hfm in
        let th1,prf' =
          reconstruct shyps rfn prf ((fm',tl skins)::unexp) lits in
        let i = length fms + length lits + index hfm shyps in
        imp_back i (imp_chain_imp (imp_front i th1)),prf'
    (* Conjunction: consumes no log step; both conjuncts replace it. *)
  | (prf,(And(p,q),skins)::unexp) ->
        let th,prf' =
          reconstruct shyps rfn prf
             ((p,skins)::(q,hack p skins)::unexp) lits in
        ante_conj th,prf'
    (* Disjunction: replay the two branches one after the other, threading  *)
    (* the log through, and combine by cases.                               *)
  | (prf,(Or(p,q),skins)::unexp) ->
        let thp,prf' =
          reconstruct shyps rfn prf ((p,skins)::unexp) lits in
        let thq,prf'' =
          reconstruct shyps rfn prf' ((q,hack p skins)::unexp) lits in
        ante_disj thp thq,prf''
    (* Universal: instantiate with the logged term (after replacet) and     *)
    (* keep the universal itself at the end of the formula list.            *)
  | (Univ(t)::prf,(Forall(x,p),skins)::unexp) ->
        let t' = replacet rfn t in
        let th1 = ispec t' (Forall(x,p)) in
        let th,prf' = reconstruct shyps rfn prf
          ((consequent(concl th1),skins)::
           unexp@[Forall(x,p),skins]) lits in
        imp_unduplicate (imp_front (length fms) (imp_trans th1 th)),prf'
    (* Branch closure against the literal at index i: start from the        *)
    (* contradiction rule for fm and pad it with the remaining formulas,    *)
    (* literals and Skolem hypotheses as extra antecedents.                 *)
  | (Literal(i)::prf,(fm,_)::unexp) ->
        let th = imp_contrad fm in
        let lits1,lits2 = chop_list i lits in
        let th1 =
          itlist imp_insert (tl lits2 @ shyps) (imp_refl False) in
        let th2 = imp_add_assum (hd lits2) th1 in
        let th3 = itlist imp_insert (map fst unexp @ lits1) th2 in
        modusponens (imp_add_assum fm th3) th,prf
    (* Requeue: the formula joins the literals; afterwards bring it back to *)
    (* the front of the resulting theorem's antecedents.                    *)
  | (Requeue::prf,(fm,_)::unexp) ->
        let th,prf' =
           reconstruct shyps rfn prf unexp (fm::lits) in
        imp_front (length unexp) th,prf';;

(* ------------------------------------------------------------------------- *)
(* Remove Skolem-type hypotheses from theorem.                               *)
(* ------------------------------------------------------------------------- *)

(* Eliminate a leading antecedent of the form (exists x. q) ==> p' from    *)
(* th.  Exactly one variable must be free in p' but not in the              *)
(* existential; the two theorems built from it are combined with            *)
(* bool_cases.  Fails if th does not start with such an antecedent.         *)
let skoscrub th =
  match concl th with
  | Imp(Imp((Exists(x,q) as ex),body),_) ->
      let [v] = subtract (fv body) (fv ex) in
      let generic = spec (Var x) (gen v th) in
      let th_left = exists_left x (imp_trans (axiom_addimp q ex) generic)
      and th_right = imp_trans (imp_add_assum ex (ex_falso body)) th in
      bool_cases th_left th_right
  | _ -> failwith "skoscrub: no Skolem antecedent";;

(* ------------------------------------------------------------------------- *)
(* "Glass" Skolemization recording correspondences.                          *)
(* ------------------------------------------------------------------------- *)

(* Skolemize fm and return the result together with only those             *)
(* correspondences that skolem added, in reversed order.  The functions     *)
(* already in fm are seeded as placeholder entries and filtered out again.  *)
let gaskolemize fm =
  let existing = map (fun (name,_) -> Fn(name,[]),False) (functions fm) in
  let sfm,all_corr = skolem fm existing in
  let is_new c = not(mem c existing) in
  sfm,rev(filter is_new all_corr);;

(* ------------------------------------------------------------------------- *)
(* Just get the existential instances from a proof.                          *)
(* ------------------------------------------------------------------------- *)

(* Walk a tableau proof log in step with the formula list and collect the  *)
(* instantiated Skolem terms used for existential formulas.                 *)
(*   proof : remaining prooflog steps                                       *)
(*   fms   : triples (formula, instantiation so far, Skolem correspondences *)
(*           still to be consumed by that formula)                          *)
(*   lits  : literals requeued so far; only threaded through, never read    *)
(* Returns the set of instances paired with the unconsumed part of the log. *)
(* The match is partial: a log that does not line up with the formulas      *)
(* raises Match_failure.                                                    *)
let rec exinsts proof fms lits =
  match (proof,fms) with
    (prf,(Exists(y,p),ifn,(t,_)::osks)::unexp) ->
        let p' = formsubst (y := t) p in
        let e,prf' = exinsts prf ((p',ifn,osks)::unexp) lits in
        insert (termsubst ifn t) e,prf'
  | (Univ(t)::prf,(Forall(x,p),ifn,sks)::unexp) ->
        let ifn' = (x |-> t) ifn in
        exinsts prf ((p,ifn',sks)::unexp@[Forall(x,p),ifn,sks]) lits
  | (prf,(And(p,q),ifn,sks)::unexp) ->
        (* Splitting a conjunction adds no literal.  The previous code     *)
        (* consed an `fm` here that is not bound in this arm and so        *)
        (* resolved to an unrelated top-level example value.               *)
        exinsts prf ((p,ifn,sks)::(q,ifn,hack p sks)::unexp) lits
  | (prf,(Or(p,q),ifn,sks)::unexp) ->
        let e1,prf' = exinsts prf ((p,ifn,sks)::unexp) lits in
        let e2,prf'' = exinsts prf' ((q,ifn,hack p sks)::unexp) lits in
        union e1 e2,prf''
  | (Literal(_)::prf,_) ->
        [],prf
  | (Requeue::prf,(fm,_,_)::unexp) ->
        exinsts prf unexp (fm::lits);;

(* ------------------------------------------------------------------------- *)
(* Set up hypotheses for Skolem functions, in left-to-right order.           *)
(* ------------------------------------------------------------------------- *)

(* For each Skolem correspondence (Fn(f,xs), exists y. q), in order, take  *)
(* the instance terms in skts headed by f and build one hypothesis per      *)
(* instance: the implication (exists y. q) ==> q, instantiated so that y    *)
(* maps to the instance and xs to its arguments, then passed through        *)
(* replace rfn.  The result is one list of hypotheses per correspondence.   *)
let rec skolem_hyps rfn sks skts =
  match sks with
  | [] -> []
  | (Fn(f,xs),(Exists(y,q) as fm))::osks ->
      let mine,others = partition (fun (Fn(g,_)) -> g = f) skts in
      let hypothesis (Fn(_,args) as inst) =
        let bind (Var x) t = x |-> t in
        let ifn = itlist2 bind (Var y::xs) (inst::args) undefined in
        replace rfn (formsubst ifn (Imp(fm,q))) in
      let later = skolem_hyps rfn osks others in
      let here = map hypothesis mine in
      here :: later;;

(* ------------------------------------------------------------------------- *)
(* Sort Skolem hypotheses into wellfounded "term depth" order.               *)
(* ------------------------------------------------------------------------- *)

(* Order the Skolem hypotheses: repeatedly pick a hypothesis whose single  *)
(* extra variable (free in the consequent but not in the antecedent) does   *)
(* not occur in any other remaining hypothesis, and move it to the output.  *)
(* dun accumulates the picked hypotheses in reverse.                        *)
let rec sortskohyps shyps dun =
  if shyps = [] then rev dun else
  let ready h =
    let others = subtract shyps [h] in
    let p,q = dest_imp h in
    let [v] = subtract (fv q) (fv p) in
    not (exists (fun g -> free_in (Var v) g) others) in
  let next = find ready shyps in
  sortskohyps (subtract shyps [next]) (next::dun);;

(* ------------------------------------------------------------------------- *)
(* Overall function.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Prove fm0 by tableaux and return a genuine theorem.  Steps: close fm0   *)
(* universally, normalize its negation by proof, Skolemize, search with     *)
(* logging, replay the log as an LCF proof under Skolem hypotheses, scrub   *)
(* those hypotheses, and finally undo the negation and the closure.         *)
let tab_rule fm0 =
  let fvs = fv fm0 in
  let closed = itlist (fun x p -> Forall(x,p)) fvs fm0 in
  let norm_th =
    iff_imp1((then_conv simplify_conv nnf_conv) (Not closed)) in
  let fm = consequent(concl norm_th) in
  let sfm,sks = gaskolemize fm in
  let _,proof = tabrefute_log [sfm] in
  let skts,[] = exinsts proof [fm,undefined,sks] [] in
  (* Name each Skolem instance term by a variable "_k". *)
  let rfn = itlist2 (fun k t -> t |-> Var("_"^string_of_int k))
                     (1 -- length skts) skts undefined in
  let skins = skolem_hyps rfn sks skts in
  let shyps = sortskohyps(itlist (@) skins []) [] in
  let refuted,[] = reconstruct shyps rfn proof [fm,skins] [] in
  let scrubbed =
    funpow (length shyps) (skoscrub ** imp_swap) refuted in
  let not_not =
    imp_trans (imp_trans (iff_imp2(axiom_not closed)) norm_th) scrubbed in
  let proved = modusponens (axiom_doubleneg closed) not_not in
  itlist (fun x -> spec (Var x)) (rev fvs) proved;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*

let p58 = tab_rule
 <<forall P Q R. forall x. exists v. exists w. forall y. forall z.
   ((P(x) /\ Q(y)) ==> ((P(v) \/ R(w))  /\ (R(z) ==> Q(v))))>>;;

let p26 = time tab_rule
 <<((exists x. P(x)) <=> (exists x. Q(x))) /\
   (forall x y. P(x) /\ Q(y) ==> (R(x) <=> U(y))) ==>
   ((forall x. P(x) ==> R(x)) <=> (forall x. Q(x) ==> U(x)))>>;;

let p28 = time tab_rule
 <<(forall x. P(x) ==> (forall x. Q(x))) /\
   ((forall x. Q(x) \/ R(x)) ==> (exists x. Q(x) /\ R(x))) /\
   ((exists x. R(x)) ==> (forall x. L(x) ==> M(x))) ==>
   (forall x. P(x) /\ L(x) ==> M(x))>>;;

let p33 = time tab_rule
 <<(forall x. P(a) /\ (P(x) ==> P(b)) ==> P(c)) <=>
   (forall x. P(a) ==> P(x) \/ P(c)) /\ (P(a) ==> P(b) ==> P(c))>>;;

let p35 = time tab_rule
 <<exists x y. P(x,y) ==> (forall x y. P(x,y))>>;;

let p38 = time tab_rule
 <<(forall x.
     P(a) /\ (P(x) ==> (exists y. P(y) /\ R(x,y))) ==>
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))) <=>
   (forall x.
     (~P(a) \/ P(x) \/ (exists z w. P(z) /\ R(x,w) /\ R(w,z))) /\
     (~P(a) \/ ~(exists y. P(y) /\ R(x,y)) \/
     (exists z w. P(z) /\ R(x,w) /\ R(w,z))))>>;;

let p45 = time tab_rule
 <<(forall x.
     P(x) /\ (forall y. G(y) /\ H(x,y) ==> J(x,y)) ==>
       (forall y. G(y) /\ H(x,y) ==> R(y))) /\
   ~(exists y. L(y) /\ R(y)) /\
   (exists x. P(x) /\ (forall y. H(x,y) ==>
     L(y)) /\ (forall y. G(y) /\ H(x,y) ==> J(x,y))) ==>
   (exists x. P(x) /\ ~(exists y. G(y) /\ H(x,y)))>>;;

let davis_putnam_example = time tab_rule
 <<exists x. exists y. forall z.
        (F(x,y) ==> (F(y,z) /\ F(z,z))) /\
        ((F(x,y) /\ G(x,y)) ==> (G(x,z) /\ G(z,z)))>>;;

let gilmore_9 = time tab_rule
 <<forall x. exists y. forall z.
        ((forall u. exists v. F(y,u,v) /\ G(y,u) /\ ~H(y,x))
          ==> (forall u. exists v. F(x,u,v) /\ G(z,u) /\ ~H(x,z))
             ==> (forall u. exists v. F(x,u,v) /\ G(y,u) /\ ~H(x,y))) /\
        ((forall u. exists v. F(x,u,v) /\ G(y,u) /\ ~H(x,y))
         ==> ~(forall u. exists v. F(x,u,v) /\ G(z,u) /\ ~H(x,z))
             ==> (forall u. exists v. F(y,u,v) /\ G(y,u) /\ ~H(y,x)) /\
                 (forall u. exists v. F(z,u,v) /\ G(y,u) /\ ~H(z,y)))>>;;

let ewd1062_1 = time tab_rule
 <<(forall x. x <= x) /\
   (forall x y z. x <= y /\ y <= z ==> x <= z) /\
   (forall x y. f(x) <= y <=> x <= g(y))
   ==> (forall x y. x <= y ==> f(x) <= f(y))>>;;

let ewd1062_2 = time tab_rule
 <<(forall x. x <= x) /\
   (forall x y z. x <= y /\ y <= z ==> x <= z) /\
   (forall x y. f(x) <= y <=> x <= g(y))
   ==> (forall x y. x <= y ==> g(x) <= g(y))>>;;

(* ------------------------------------------------------------------------- *)
(* Some further examples.                                                    *)
(* ------------------------------------------------------------------------- *)

let gilmore_3 = time tab_rule
 <<exists x. forall y z.
        ((M(y,z) ==> (G(y) ==> H(x))) ==> M(x,x)) /\
        ((M(z,x) ==> G(x)) ==> H(z)) /\
        M(x,y)
        ==> M(z,z)>>;;

let gilmore_4 = time tab_rule
 <<exists x y. forall z.
        (M(x,y) ==> M(y,z) /\ M(z,z)) /\
        (M(x,y) /\ G(x,y) ==> G(x,z) /\ G(z,z))>>;;

let gilmore_5 = time tab_rule
 <<(forall x. exists y. M(x,y) \/ M(y,x)) /\
   (forall x y. M(y,x) ==> M(y,y))
   ==> exists z. M(z,z)>>;;

let gilmore_6 = time tab_rule
 <<forall x. exists y.
        (exists u. forall v. M(u,x) ==> G(v,u) /\ G(u,x))
        ==> (exists u. forall v. M(u,y) ==> G(v,u) /\ G(u,y)) \/
            (forall u v. exists w. G(v,u) \/ H(w,y,u) ==> G(u,w))>>;;

let gilmore_7 = time tab_rule
 <<(forall x. K(x) ==> exists y. L(y) /\ (M(x,y) ==> G(x,y))) /\
   (exists z. K(z) /\ forall u. L(u) ==> M(z,u))
   ==> exists v w. K(v) /\ L(w) /\ G(v,w)>>;;

let gilmore_8 = time tab_rule
 <<exists x. forall y z.
        ((M(y,z) ==> (G(y) ==> (forall u. exists v. H(u,v,x)))) ==> M(x,x)) /\
        ((M(z,x) ==> G(x)) ==> (forall u. exists v. H(u,v,z))) /\
        M(x,y)
        ==> M(z,z)>>;;

let ewd_1038' = time tab_rule
 <<(forall x y z. x <= y /\ y <= z ==> x <= z) /\
   (forall x y. x < y <=> ~(y <= x))
   ==> (forall x y z. x <= y /\ y < z ==> x < z) /\
       (forall x y z. x < y /\ y <= z ==> x < z)>>;;

*)
(* ========================================================================= *)
(* Goals, LCF-like tactics and Mizar-like proofs.                            *)
(*                                                                           *)
(* Copyright (c) 2003, John Harrison. (See "LICENSE.txt" for details.)       *)
(* ========================================================================= *)

(* A proof state: the list of open subgoals, each a list of labelled       *)
(* assumptions plus the formula to prove, together with a justification     *)
(* function that turns theorems for the subgoals into the final theorem.    *)
type goals =
  Goals of ((string * fol formula) list * fol formula)list *
           (Proven.thm list -> Proven.thm);;

(* ------------------------------------------------------------------------- *)
(* Printer for goals (just shows first goal plus total number).              *)
(* ------------------------------------------------------------------------- *)

(* Print a goal state: a count line, then the first subgoal with its       *)
(* labelled hypotheses (in reverse of stored order) above the formula to    *)
(* prove.  An empty goal list prints "No subgoals".                         *)
let print_goal =
  let print_hyp (label,hyp) =
    open_hbox(); print_string(label^":"); print_space();
    print_formula print_atom 0 hyp; print_newline(); close_box() in
  fun (Goals(gls,_)) ->
    match gls with
    | [] -> print_string "No subgoals"
    | (asl,w)::rest ->
        print_newline();
        (match rest with
         | [] -> print_string "1 subgoal:"
         | _ -> print_int (length gls);
                print_string " subgoals starting with");
        print_newline();
        do_list print_hyp (rev asl);
        print_string "---> ";
        open_hvbox 0; print_formula print_atom 0 w; close_box();
        print_newline();;

(*
#install_printer print_goal;;
*)

(* ------------------------------------------------------------------------- *)
(* Setting up goals and terminating them in a theorem.                       *)
(* ------------------------------------------------------------------------- *)

(* Initial proof state for fm: one subgoal with no assumptions.  The       *)
(* justification detaches `truth` from the single subgoal theorem and       *)
(* checks that the outcome is exactly fm.                                   *)
let set_goal fm =
  let justify_top [th] =
    let result = modusponens th truth in
    if concl result = fm then result else failwith "wrong theorem" in
  Goals([[],fm],justify_top);;

(* Once no subgoals remain, run the justification to get the theorem.      *)
let extract_thm = function
  | Goals([],jfn) -> jfn []
  | _ -> failwith "extract_thm: unsolved goals";;

(* ------------------------------------------------------------------------- *)
(* Running a series of proof steps one by one on goals.                      *)
(* ------------------------------------------------------------------------- *)

(* Apply the proof steps in prf to the goal state g, first step first.     *)
let run prf g =
  let rec apply_all steps state =
    match steps with
    | [] -> state
    | step::rest -> apply_all rest (step state) in
  apply_all prf g;;

(* ------------------------------------------------------------------------- *)
(* Handy idiom for tactic that does not split subgoals.                      *)
(* ------------------------------------------------------------------------- *)

(* Wrap a justification so that tfn is applied to the first theorem        *)
(* before the list is handed on to jfn.  Fails on an empty list.            *)
let jmodify jfn tfn = function
  | [] -> failwith "jmodify: no first theorem"
  | th::oths -> jfn (tfn th :: oths);;

(* ------------------------------------------------------------------------- *)
(* Append contextual hypothesis to unconditional theorem.                    *)
(* ------------------------------------------------------------------------- *)

(* Put the conjunction of the first subgoal's assumptions in front of th   *)
(* as an antecedent.                                                        *)
let assumptate (Goals((asl,_)::_,_)) th =
  let context = list_conj (map snd asl) in
  add_assum context th;;

(* ------------------------------------------------------------------------- *)
(* Turn assumptions p1,...,pn into theorems |- p1 /\ ... /\ pn ==> pi        *)
(* ------------------------------------------------------------------------- *)

(* For labelled assumptions p1,...,pn return, under the same labels, the   *)
(* theorems |- p1 /\ ... /\ pn ==> pi.  Built from the tail: the head      *)
(* gets the left projection, the others are reached through the right one. *)
let rec assumps = function
  | [] -> []
  | [label,p] -> [label,imp_refl p]
  | (label,p)::others ->
      let tail_ths = assumps others in
      let q = antecedent(concl(snd(hd tail_ths))) in
      let skip = and_right p q in
      let through (l,th) = l,imp_trans skip th in
      (label,and_left p q)::map through tail_ths;;

(* Theorem extracting the first assumption from the conjunction of all of  *)
(* them; with a single assumption this is just reflexivity.                 *)
let firstassum asl =
  let p = snd(hd asl) and q = list_conj(map snd (tl asl)) in
  match tl asl with
  | [] -> imp_refl p
  | _ -> and_left p q;;

(* ------------------------------------------------------------------------- *)
(* Another inference rule: |- P[t] ==> exists x. P[x]                        *)
(* ------------------------------------------------------------------------- *)

(* Existential introduction, |- P[t] ==> exists x. P[x].  Starts from the  *)
(* instance of (forall x. ~p) at t, turns it around through the             *)
(* definition of negation, and finishes with the definition of exists.      *)
let right_exists x t p =
  let all_not = Forall(x,Not p) in
  let inst = ispec t all_not in
  let Not(p') = consequent(concl inst) in
  let to_false = imp_trans inst (iff_imp1(axiom_not p')) in
  let lifted = imp_add_concl False to_false in
  let swapped =
    imp_trans (imp_swap(imp_refl(Imp(p',False)))) lifted in
  let neg_all = imp_trans swapped (iff_imp2(axiom_not all_not)) in
  imp_trans neg_all (iff_imp2(axiom_exists x p));;

(* ------------------------------------------------------------------------- *)
(* Two simple natural deduction constructs.                                  *)
(* ------------------------------------------------------------------------- *)

(* Natural-deduction "fix": for a goal (forall x. p), continue with p      *)
(* where x is replaced by the variable a.  Refuses if a is free in any      *)
(* assumption.  The justification generalizes over a and renames back.      *)
let fix a (Goals((asl,Forall(x,p))::gls,jfn)) =
  let clashes (_,hyp) = mem a (fv hyp) in
  if exists clashes asl
  then failwith "fix: variable free in assumptions" else
  let body = formsubst(x := Var a) p in
  let rebuild th =
    imp_trans (gen_right a th) (alpha x (Forall(a,body))) in
  Goals((asl,body)::gls,jmodify jfn rebuild);;

(* Natural-deduction "take": for a goal (exists x. p), parse the witness   *)
(* term from s and continue with p at that witness.  The justification      *)
(* reintroduces the existential with right_exists.                          *)
let take s (Goals((asl,Exists(x,p))::gls,jfn)) =
  let witness = parset s in
  let body = formsubst(x := witness) p in
  let rebuild th = imp_trans th (right_exists x witness p) in
  Goals((asl,body)::gls,jmodify jfn rebuild);;

(* ------------------------------------------------------------------------- *)
(* Parse a labelled formula, recognizing "thesis" and "antecedent"           *)
(* ------------------------------------------------------------------------- *)

(* Expand the pseudo-atoms "thesis" and "antecedent" relative to the       *)
(* given thesis.  "antecedent" stays an ordinary atom when the thesis is    *)
(* not an implication; every other atom is returned unchanged.              *)
let expand_atom thesis at =
  match at with
  | R("thesis",[]) -> thesis
  | R("antecedent",[]) ->
      (try let ante,_ = dest_imp thesis in ante
       with Failure _ -> Atom at)
  | _ -> Atom at;;

(* Apply expand_atom to every atom of fm.                                   *)
let expand thesis fm =
  let rewrite = expand_atom thesis in
  onatoms rewrite fm;;

(* ------------------------------------------------------------------------- *)
(* Restore old version.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Top-level string constant.  NOTE(review): nothing in the visible part    *)
(* of the file reads this binding (expand_atom and expand use a parameter   *)
(* of the same name) -- confirm which later code depends on it.             *)
let thesis = "thesis";;

(* Parse one formula, optionally preceded by "label :", from the token     *)
(* list, expanding thesis/antecedent against the current goal formula.      *)
(* Returns the labelled formula (label "" if none) and the leftover tokens. *)
let parself (Goals((_,w)::_,_)) toks =
  let label,body =
    match toks with
    | name::":"::rest -> name,rest
    | _ -> "",toks in
  let fm,leftover = parse_formula parse_atom [] body in
  (label,expand w fm),leftover;;

(* Parse a sequence of labelled formulas separated by the token "and".     *)
let rec parselfs g toks =
  let first,rest = parself g toks in
  match rest with
  | "and"::more ->
      let others,leftover = parselfs g more in
      first::others,leftover
  | _ -> [first],rest;;

(* Parse the whole string s as labelled formulas; return them together     *)
(* with their conjunction.  Fails if any input is left unparsed.            *)
let parse_labelled_formulas g s =
  let fms,leftover = parselfs g (lex(explode s)) in
  let conj = end_itlist (fun p q -> And(p,q)) (map snd fms) in
  match leftover with
  | [] -> fms,conj
  | _ -> failwith "parse_labelled_formulas: unparsed input";;

(* Parse s as exactly one labelled formula; fail if there are several.     *)
let parse_labelled_formula g s =
  let fms,_ = parse_labelled_formulas g s in
  match fms with
  | [label,fm] -> label,fm
  | _ -> failwith "too many formulas";;

(* ------------------------------------------------------------------------- *)
(* |- p1 /\ .. /\ pn ==> q to |- pi+1 /\ ... /\ pn ==> p1 /\ .. /\ pi ==> q  *)
(* ------------------------------------------------------------------------- *)

(* Split the first i conjuncts off the antecedent of th: i rounds of       *)
(* shunt-then-swap peel them off, i-1 rounds of imp_front 2 followed by     *)
(* ante_conj glue them back into one conjunction, with swaps in between.    *)
let multishunt i th =
  let shunt_swap t = imp_swap (shunt t) in
  let regroup t = ante_conj (imp_front 2 t) in
  let split = funpow i shunt_swap th in
  let merged = funpow (i-1) regroup (imp_swap split) in
  imp_swap merged;;

(* ------------------------------------------------------------------------- *)
(* Add labelled formulas to the assumption list.                             *)
(* ------------------------------------------------------------------------- *)

(* For a goal p ==> q, parse the labelled formulas in s, check that their  *)
(* conjunction is exactly p, add them to the assumptions and continue       *)
(* with q.                                                                  *)
let assume s (Goals((asl,Imp(p,q))::gls,jfn) as gl) =
  let labelled,parsed = parse_labelled_formulas gl s in
  if p <> parsed then failwith "assume: doesn't match antecedent" else
  let discharge th =
    match asl with
    | [] -> add_assum True th
    | _ -> multishunt (length labelled) th in
  Goals((labelled@asl,q)::gls,jmodify jfn discharge);;

(* ------------------------------------------------------------------------- *)
(* Delayed version of tableau rule, for speed of first phase.                *)
(* ------------------------------------------------------------------------- *)

(* Two-phase variant of tab_rule.  Applying it to fm0 immediately does the *)
(* search: close fm0 universally, compute the simplified NNF of its         *)
(* negation with the plain (non-proof-producing) nnf and simplify,          *)
(* Skolemize, and run the logging tableau.  The returned thunk does the     *)
(* proof-producing work only when called.                                   *)
let delayed_tab_rule fm0 =
  let fvs = fv fm0 in
  let fm1 = itlist (fun x p -> Forall(x,p)) fvs fm0 in
  let fm = nnf(simplify(Not fm1)) in
  let sfm,sks = gaskolemize fm in
  let _,proof = tabrefute_log [sfm] in
  fun () ->
    (* Normalization by proof is deferred to here.  NOTE(review): the rest *)
    (* relies on its right-hand side coinciding with fm computed above.    *)
    let thn = iff_imp1((then_conv simplify_conv nnf_conv) (Not fm1)) in
    let skts,[] = exinsts proof [fm,undefined,sks] [] in
    (* Name each Skolem instance term by a variable "_k". *)
    let rfn = itlist2 (fun k t -> t |-> Var("_"^string_of_int k))
                       (1 -- length skts) skts undefined in
    let skins = skolem_hyps rfn sks skts in
    let shyps = sortskohyps(itlist (@) skins []) [] in
    let th1,[] = reconstruct shyps rfn proof [fm,skins] [] in
    (* One skoscrub round per Skolem hypothesis. *)
    let th2 = funpow (length shyps) (skoscrub ** imp_swap) th1 in
    let th3 = imp_trans (imp_trans (iff_imp2(axiom_not fm1)) thn) th2 in
    let th4 = modusponens (axiom_doubleneg fm1) th3 in
    (* Undo the universal closure over the original free variables. *)
    itlist (fun x -> spec (Var x)) (rev fvs) th4;;

(* ------------------------------------------------------------------------- *)
(* Main automatic justification step.                                        *)
(* ------------------------------------------------------------------------- *)

(* [justify byfn hyps gl p] returns a thunk producing a theorem for [p].
   [byfn hyps p gl] supplies a list of premise formulas and a thunk for
   the matching theorems.  If the only premise is [p] itself, the thunk
   simply returns the first of those theorems.  Otherwise the implication
   chain from the premises to [p] is given to [delayed_tab_rule] right
   away, and the returned thunk combines its result with the premise
   theorems: through [assumptate gl] when there are no premises, and
   through [imp_trans_chain] otherwise. *)
let justify byfn hyps gl p =
  let premises,premise_thms = byfn hyps p gl in
  if premises = [p] then (fun () -> hd (premise_thms ())) else
  let chained = itlist (fun a b -> Imp(a,b)) premises p in
  let tableau = delayed_tab_rule chained in
  match premises with
    [] -> (fun () -> assumptate gl (tableau ()))
  | _ -> (fun () -> imp_trans_chain (premise_thms ()) (tableau ()));;

(* ------------------------------------------------------------------------- *)
(* Produce canonical theorem from list of theorems or assumption labels.     *)
(* ------------------------------------------------------------------------- *)

(* [by hyps p gl] treats [hyps] as labels into the assumption list of the
   first goal.  The first component of the result is the list of formulas
   found under those labels (looked up immediately with [assoc]); the
   second is a thunk that looks the same labels up in [assumps asl].
   The argument [p] is not used. *)
let by hyps p (Goals((asl,_)::_,_)) =
  let lookup table label = assoc label table in
  let formulas = map (lookup asl) hyps in
  let theorems () = map (lookup (assumps asl)) hyps in
  formulas,theorems;;

(* ------------------------------------------------------------------------- *)
(* Import "external" theorem.                                                *)
(* ------------------------------------------------------------------------- *)

(* [using ths p g] first applies [gen] to each theorem in [ths], once for
   every free variable of its conclusion.  It returns the conclusions of
   the resulting theorems, paired with a thunk that maps [assumptate g]
   over them.  The argument [p] is not used. *)
let using ths p g =
  let close th = itlist gen (fv (concl th)) th in
  let closed = map close ths in
  let conclusions = map concl closed in
  conclusions,(fun () -> map (assumptate g) closed);;

(* ------------------------------------------------------------------------- *)
(* Trivial justification, producing no hypotheses.                           *)
(* ------------------------------------------------------------------------- *)

(* [at] ignores all three of its arguments and returns an empty list
   paired with a function that returns an empty list for any argument.
   [once] is just the empty list; it exists so that a proof script can
   write the two words next to each other, with [once] as the first
   (ignored) argument of [at]. *)
let at _once _p _gl = ([],(fun _ -> []))
and once = [];;

(* ------------------------------------------------------------------------- *)
(* Main actions on canonical theorem.                                        *)
(* ------------------------------------------------------------------------- *)

(* [have s byfn hyps gl] parses [s] into a labelled formula, obtains a
   delayed theorem for it from [justify], and pushes the labelled formula
   onto the front of the first goal's assumptions; the thesis is left
   unchanged.  The justification is extended through [jmodify]: with no
   prior assumptions the delayed theorem is chained in with [imp_trans]
   alone, otherwise the incoming theorem is passed through [shunt] first
   and the result through [imp_unduplicate]. *)
let have s byfn hyps (Goals((asl,w)::gls,jfn) as gl) =
  let (label,fm) = parse_labelled_formula gl s in
  let lemma = justify byfn hyps gl fm in
  let splice =
    match asl with
      [] -> (fun pth -> imp_trans (lemma ()) pth)
    | _ -> (fun pth -> imp_unduplicate (imp_trans (lemma ()) (shunt pth))) in
  Goals(((label,fm)::asl,w)::gls,jmodify jfn splice);;

(* [case_split s byfn hyps gl] parses [s] into a labelled formula that
   must be a disjunction Or(p,q); anything else is a pattern-match
   failure.  The first goal is replaced by two goals with the same
   thesis, one with [p] and one with [q] added (under the same label) to
   the assumptions.  The new justification consumes the theorems for
   those two goals, merges them with [ante_disj] after [shunt], chains
   the delayed theorem for the disjunction in front, and passes the
   result on to the previous justification. *)
let case_split s byfn hyps (Goals((asl,w)::gls,jfn) as gl) =
  let (label,(Or(left,right) as disj)) = parse_labelled_formula gl s in
  let lemma = justify byfn hyps gl disj in
  let branch fm = ((label,fm)::asl,w) in
  let merge = function
      lth::rth::rest ->
        let both = ante_disj (shunt lth) (shunt rth) in
        let joined = imp_unduplicate (imp_trans (lemma ()) both) in
        jfn (joined::rest) in
  Goals(branch left::branch right::gls,merge);;

(* [consider (a,s) byfn hyps gl] fails if the variable [a] occurs free in
   the thesis or in any assumption formula of the first goal.  Otherwise
   [s] is parsed into a labelled formula [p], a delayed theorem for
   Exists(a,p) is obtained from [justify], and [p] is pushed onto the
   assumptions with the thesis unchanged.  The justification step applies
   [shunt] and [exists_left a] to the incoming theorem, chains the
   existential theorem in front, and applies [imp_unduplicate]. *)
let consider (a,s) byfn hyps (Goals((asl,w)::gls,jfn) as gl) =
  let occurs_free fm = mem a (fv fm) in
  if exists occurs_free (w::map snd asl)
  then failwith "consider: variable free in assumptions" else
  let (label,body) = parse_labelled_formula gl s in
  let witness = justify byfn hyps gl (Exists(a,body)) in
  let eliminate pth =
    let generalized = exists_left a (shunt pth) in
    imp_unduplicate (imp_trans (witness ()) generalized) in
  Goals(((label,body)::asl,w)::gls,jmodify jfn eliminate);;

(* ------------------------------------------------------------------------- *)
(* Thesis modification.                                                      *)
(* ------------------------------------------------------------------------- *)

(* [modifythesis fm thesis] returns the thesis that remains once [fm] is
   available, together with a function that combines two theorems (the
   second argument is ignored in the first case below).  Cases, tried in
   this order:
     - [fm] equals [thesis]: remaining thesis is True;
     - [thesis] is And(p,q) and [fm] is [p]: remaining thesis is [q];
     - [thesis] is Iff(p,q) and [fm] is Imp(p,q): remaining thesis is
       Imp(q,p).
   An And or Iff thesis whose relevant part does not match [fm], and any
   other thesis, raises Failure with a specific message. *)
let modifythesis fm thesis =
  match thesis with
    _ when fm = thesis -> (True,(fun fth _ -> fth))
  | And(p,q) when fm = p ->
      (q,(fun pth qth -> imp_trans_chain [pth; qth] (and_pair p q)))
  | And(_,_) -> failwith "modifythesis: doesn't match first conjunct"
  | Iff(p,q) when fm = Imp(p,q) ->
      (Imp(q,p),
       (fun pth qth -> imp_trans_chain [pth; qth] (axiom_impiff p q)))
  | Iff(_,_) -> failwith "modifythesis: doesn't match implication"
  | _ -> failwith "modifythesis: don't have anything to do";;

(* ------------------------------------------------------------------------- *)
(* Terminating steps.                                                        *)
(* ------------------------------------------------------------------------- *)

(* [thus s byfn hyps gl0] first performs a [have] step, then uses the
   formula of the first assumption of the resulting goal to reduce the
   thesis via [modifythesis].  The new justification takes the theorem
   for the reduced thesis, combines it with [firstassum asl] using the
   function returned by [modifythesis], and passes the result on to the
   justification produced by [have]. *)
let thus s byfn hyps gl0 =
  let (Goals((asl,w)::gls,jfn)) = have s byfn hyps gl0 in
  let (_,newest) = hd asl in
  let (remaining,combine) = modifythesis newest w in
  let finish = function
      pth::rest -> jfn (combine (firstassum asl) pth :: rest) in
  Goals((asl,remaining)::gls,finish);;

(* [qed gl] discharges the first goal, which is only allowed when its
   thesis is True; otherwise Failure is raised.  The new justification
   puts [truth] in front of the theorem list before calling the previous
   justification. *)
let qed (Goals((_,w)::gls,jfn)) =
  match w with
    True -> Goals(gls,(fun ths -> jfn (truth::ths)))
  | _ -> failwith "qed: unproven thesis";;

(* ------------------------------------------------------------------------- *)
(* The "so" continuation.                                                    *)
(* ------------------------------------------------------------------------- *)

(* [so cont args byfn hyps gl] runs the proof step [cont] with a wrapped
   justification function.  The wrapper calls [byfn] and then adds the
   formula of the first assumption of the current goal to the front of
   the premises, and [firstassum asl] to the front of the premise
   theorems. *)
let so cont args byfn hyps (Goals((asl,_)::_,_) as gl) =
  let byfn' hyps p g =
    let ps,ths = byfn hyps p g in
    let latest = snd (hd asl) in
    latest::ps,(fun () -> firstassum asl::ths ()) in
  cont args byfn' hyps gl;;

(* [hence] is [thus] run through [so]: a [thus] step whose justification
   is wrapped by [so]. *)
let hence args byfn hyps gl =
  let step = so thus args in
  step byfn hyps gl;;

(* ------------------------------------------------------------------------- *)
(* Nested sub-proof using same model.                                        *)
(* ------------------------------------------------------------------------- *)

(* [proof prf p gl] is a justification function backed by a nested proof.
   It runs [prf] on a fresh single goal with the current assumptions and
   thesis [p], using [hd] as the initial justification.  If no goals are
   left, it returns [p] as the only premise plus a thunk that yields the
   one theorem built by the nested justification.  Any leftover goal
   raises Failure. *)
let proof prf p (Goals((asl,_)::_,_)) =
  let subgoal = Goals([(asl,p)],hd) in
  match run prf subgoal with
    Goals([],fn) -> ([p],(fun () -> [fn []]))
  | _ -> failwith "unsolved goals in nested proof";;

(* ------------------------------------------------------------------------- *)
(* General proof construct.                                                  *)
(* ------------------------------------------------------------------------- *)

(* [prove fm prf] runs the proof script [prf] on the initial goal state
   [set_goal fm], prints a progress message followed by a newline, and
   then returns [extract_thm] applied to the final goal state. *)
let prove fm prf =
  let final = run prf (set_goal fm) in
  let () = print_string "Goals proved; reconstructing..." in
  let () = print_newline () in
  extract_thm final;;

(* ------------------------------------------------------------------------- *)
(* Examples.                                                                 *)
(* ------------------------------------------------------------------------- *)

(*
prove <<p(a) ==> (forall x. p(x) ==> p(f(x)))
        ==> exists y. p(y) /\ p(f(y))>>
      [thus thesis at once;
       qed];;

prove
 <<(forall x. x <= x) /\
   (forall x y z. x <= y /\ y <= z ==> x <= z) /\
   (forall x y. f(x) <= y <=> x <= g(y))
   ==> (forall x y. x <= y ==> f(x) <= f(y)) /\
       (forall x y. x <= y ==> g(x) <= g(y))>>
  [assume "A: antecedent";
   hence "forall x y. x <= y ==> f(x) <= f(y)" at once;
   thus thesis by ["A"];
   qed];;

prove
 <<(exists x. p(x)) ==> (forall x. p(x) ==> p(f(x)))
   ==> exists y. p(f(f(f(f(y)))))>>
  [assume "A: exists x. p(x)";
   assume "B: forall x. p(x) ==> p(f(x))";
   have "C: forall x. p(x) ==> p(f(f(f(f(x)))))" proof
    [have "forall x. p(x) ==> p(f(f(x)))" by ["B"];
     hence thesis at once;
     qed];
   consider ("a","p(a)") by ["A"];
   take "a";
   hence thesis by ["C"];
   qed];;

*)
