diff --git a/src/HOL/Data_Structures/Array_Braun.thy b/src/HOL/Data_Structures/Array_Braun.thy --- a/src/HOL/Data_Structures/Array_Braun.thy +++ b/src/HOL/Data_Structures/Array_Braun.thy @@ -1,663 +1,731 @@ (* Author: Tobias Nipkow, with contributions by Thomas Sewell *) section "Arrays via Braun Trees" theory Array_Braun - imports - Array_Specs - Braun_Tree +imports + Time_Funs + Array_Specs + Braun_Tree begin subsection "Array" fun lookup1 :: "'a tree \ nat \ 'a" where "lookup1 (Node l x r) n = (if n=1 then x else lookup1 (if even n then l else r) (n div 2))" fun update1 :: "nat \ 'a \ 'a tree \ 'a tree" where "update1 n x Leaf = Node Leaf x Leaf" | "update1 n x (Node l a r) = (if n=1 then Node l x r else if even n then Node (update1 (n div 2) x l) a r else Node l a (update1 (n div 2) x r))" fun adds :: "'a list \ nat \ 'a tree \ 'a tree" where "adds [] n t = t" | "adds (x#xs) n t = adds xs (n+1) (update1 (n+1) x t)" fun list :: "'a tree \ 'a list" where "list Leaf = []" | "list (Node l x r) = x # splice (list l) (list r)" subsubsection "Functional Correctness" lemma size_list: "size(list t) = size t" by(induction t)(auto) lemma minus1_div2: "(n - Suc 0) div 2 = (if odd n then n div 2 else n div 2 - 1)" by auto arith lemma nth_splice: "\ n < size xs + size ys; size ys \ size xs; size xs \ size ys + 1 \ \ splice xs ys ! n = (if even n then xs else ys) ! (n div 2)" proof(induction xs ys arbitrary: n rule: splice.induct) qed (auto simp: nth_Cons' minus1_div2) lemma div2_in_bounds: "\ braun (Node l x r); n \ {1..size(Node l x r)}; n > 1 \ \ (odd n \ n div 2 \ {1..size r}) \ (even n \ n div 2 \ {1..size l})" by auto arith declare upt_Suc[simp del] paragraph \\<^const>\lookup1\\ lemma nth_list_lookup1: "\braun t; i < size t\ \ list t ! i = lookup1 t (i+1)" proof(induction t arbitrary: i) case Leaf thus ?case by simp next case Node thus ?case using div2_in_bounds[OF Node.prems(1), of "i+1"] by (auto simp: nth_splice minus1_div2 size_list) qed lemma list_eq_map_lookup1: "braun t \ list t = map (lookup1 t) [1..\<^const>\update1\\ lemma size_update1: "\ braun t; n \ {1.. size t} \ \ size(update1 n x t) = size t" proof(induction t arbitrary: n) case Leaf thus ?case by simp next case Node thus ?case using div2_in_bounds[OF Node.prems] by simp qed lemma braun_update1: "\braun t; n \ {1.. size t} \ \ braun(update1 n x t)" proof(induction t arbitrary: n) case Leaf thus ?case by simp next case Node thus ?case using div2_in_bounds[OF Node.prems] by (simp add: size_update1) qed lemma lookup1_update1: "\ braun t; n \ {1.. size t} \ \ lookup1 (update1 n x t) m = (if n=m then x else lookup1 t m)" proof(induction t arbitrary: m n) case Leaf then show ?case by simp next have aux: "\ odd n; odd m \ \ n div 2 = (m::nat) div 2 \ m=n" for m n using odd_two_times_div_two_succ by fastforce case Node thus ?case using div2_in_bounds[OF Node.prems] by (auto simp: aux) qed lemma list_update1: "\ braun t; n \ {1.. size t} \ \ list(update1 n x t) = (list t)[n-1 := x]" by(auto simp add: list_eq_map_lookup1 list_eq_iff_nth_eq lookup1_update1 size_update1 braun_update1) text \A second proof of @{thm list_update1}:\ lemma diff1_eq_iff: "n > 0 \ n - Suc 0 = m \ n = m+1" by arith lemma list_update_splice: "\ n < size xs + size ys; size ys \ size xs; size xs \ size ys + 1 \ \ (splice xs ys) [n := x] = (if even n then splice (xs[n div 2 := x]) ys else splice xs (ys[n div 2 := x]))" by(induction xs ys arbitrary: n rule: splice.induct) (auto split: nat.split) lemma list_update2: "\ braun t; n \ {1.. size t} \ \ list(update1 n x t) = (list t)[n-1 := x]" proof(induction t arbitrary: n) case Leaf thus ?case by simp next case (Node l a r) thus ?case using div2_in_bounds[OF Node.prems] by(auto simp: list_update_splice diff1_eq_iff size_list split: nat.split) qed paragraph \\<^const>\adds\\ lemma splice_last: shows "size ys \ size xs \ splice (xs @ [x]) ys = splice xs ys @ [x]" and "size ys+1 \ size xs \ splice xs (ys @ [y]) = splice xs ys @ [y]" by(induction xs ys arbitrary: x y rule: splice.induct) (auto) lemma list_add_hi: "braun t \ list(update1 (Suc(size t)) x t) = list t @ [x]" by(induction t)(auto simp: splice_last size_list) lemma size_add_hi: "braun t \ m = size t \ size(update1 (Suc m) x t) = size t + 1" by(induction t arbitrary: m)(auto) lemma braun_add_hi: "braun t \ braun(update1 (Suc(size t)) x t)" by(induction t)(auto simp: size_add_hi) lemma size_braun_adds: "\ braun t; size t = n \ \ size(adds xs n t) = size t + length xs \ braun (adds xs n t)" by(induction xs arbitrary: t n)(auto simp: braun_add_hi size_add_hi) lemma list_adds: "\ braun t; size t = n \ \ list(adds xs n t) = list t @ xs" by(induction xs arbitrary: t n)(auto simp: size_braun_adds list_add_hi size_add_hi braun_add_hi) subsubsection "Array Implementation" interpretation A: Array where lookup = "\(t,l) n. lookup1 t (n+1)" and update = "\n x (t,l). (update1 (n+1) x t, l)" and len = "\(t,l). l" and array = "\xs. (adds xs 0 Leaf, length xs)" and invar = "\(t,l). braun t \ l = size t" and list = "\(t,l). list t" proof (standard, goal_cases) case 1 thus ?case by (simp add: nth_list_lookup1 split: prod.splits) next case 2 thus ?case by (simp add: list_update1 split: prod.splits) next case 3 thus ?case by (simp add: size_list split: prod.splits) next case 4 thus ?case by (simp add: list_adds) next case 5 thus ?case by (simp add: braun_update1 size_update1 split: prod.splits) next case 6 thus ?case by (simp add: size_braun_adds split: prod.splits) qed subsection "Flexible Array" fun add_lo where "add_lo x Leaf = Node Leaf x Leaf" | "add_lo x (Node l a r) = Node (add_lo a r) x l" fun merge where "merge Leaf r = r" | "merge (Node l a r) rr = Node rr a (merge l r)" fun del_lo where "del_lo Leaf = Leaf" | "del_lo (Node l a r) = merge l r" fun del_hi :: "nat \ 'a tree \ 'a tree" where "del_hi n Leaf = Leaf" | "del_hi n (Node l x r) = (if n = 1 then Leaf else if even n then Node (del_hi (n div 2) l) x r else Node l x (del_hi (n div 2) r))" subsubsection "Functional Correctness" paragraph \\<^const>\add_lo\\ lemma list_add_lo: "braun t \ list (add_lo a t) = a # list t" by(induction t arbitrary: a) auto lemma braun_add_lo: "braun t \ braun(add_lo x t)" by(induction t arbitrary: x) (auto simp add: list_add_lo simp flip: size_list) paragraph \\<^const>\del_lo\\ lemma list_merge: "braun (Node l x r) \ list(merge l r) = splice (list l) (list r)" by (induction l r rule: merge.induct) auto lemma braun_merge: "braun (Node l x r) \ braun(merge l r)" by (induction l r rule: merge.induct)(auto simp add: list_merge simp flip: size_list) lemma list_del_lo: "braun t \ list(del_lo t) = tl (list t)" by (cases t) (simp_all add: list_merge) lemma braun_del_lo: "braun t \ braun(del_lo t)" by (cases t) (simp_all add: braun_merge) paragraph \\<^const>\del_hi\\ lemma list_Nil_iff: "list t = [] \ t = Leaf" by(cases t) simp_all lemma butlast_splice: "butlast (splice xs ys) = (if size xs > size ys then splice (butlast xs) ys else splice xs (butlast ys))" by(induction xs ys rule: splice.induct) (auto) lemma list_del_hi: "braun t \ size t = st \ list(del_hi st t) = butlast(list t)" by (induction t arbitrary: st) (auto simp: list_Nil_iff size_list butlast_splice) lemma braun_del_hi: "braun t \ size t = st \ braun(del_hi st t)" by (induction t arbitrary: st) (auto simp: list_del_hi simp flip: size_list) subsubsection "Flexible Array Implementation" interpretation AF: Array_Flex where lookup = "\(t,l) n. lookup1 t (n+1)" and update = "\n x (t,l). (update1 (n+1) x t, l)" and len = "\(t,l). l" and array = "\xs. (adds xs 0 Leaf, length xs)" and invar = "\(t,l). braun t \ l = size t" and list = "\(t,l). list t" and add_lo = "\x (t,l). (add_lo x t, l+1)" and del_lo = "\(t,l). (del_lo t, l-1)" and add_hi = "\x (t,l). (update1 (Suc l) x t, l+1)" and del_hi = "\(t,l). (del_hi l t, l-1)" proof (standard, goal_cases) case 1 thus ?case by (simp add: list_add_lo split: prod.splits) next case 2 thus ?case by (simp add: list_del_lo split: prod.splits) next case 3 thus ?case by (simp add: list_add_hi braun_add_hi split: prod.splits) next case 4 thus ?case by (simp add: list_del_hi split: prod.splits) next case 5 thus ?case by (simp add: braun_add_lo list_add_lo flip: size_list split: prod.splits) next case 6 thus ?case by (simp add: braun_del_lo list_del_lo flip: size_list split: prod.splits) next case 7 thus ?case by (simp add: size_add_hi braun_add_hi split: prod.splits) next case 8 thus ?case by (simp add: braun_del_hi list_del_hi flip: size_list split: prod.splits) qed subsection "Faster" subsubsection \Size\ fun diff :: "'a tree \ nat \ nat" where "diff Leaf _ = 0" | "diff (Node l x r) n = (if n=0 then 1 else if even n then diff r (n div 2 - 1) else diff l (n div 2))" fun size_fast :: "'a tree \ nat" where "size_fast Leaf = 0" | "size_fast (Node l x r) = (let n = size_fast r in 1 + 2*n + diff l n)" declare Let_def[simp] lemma diff: "braun t \ size t : {n, n + 1} \ diff t n = size t - n" by (induction t arbitrary: n) auto lemma size_fast: "braun t \ size_fast t = size t" by (induction t) (auto simp add: diff) subsubsection \Initialization with 1 element\ fun braun_of_naive :: "'a \ nat \ 'a tree" where "braun_of_naive x n = (if n=0 then Leaf else let m = (n-1) div 2 in if odd n then Node (braun_of_naive x m) x (braun_of_naive x m) else Node (braun_of_naive x (m + 1)) x (braun_of_naive x m))" fun braun2_of :: "'a \ nat \ 'a tree * 'a tree" where "braun2_of x n = (if n = 0 then (Leaf, Node Leaf x Leaf) else let (s,t) = braun2_of x ((n-1) div 2) in if odd n then (Node s x s, Node t x s) else (Node t x s, Node t x t))" definition braun_of :: "'a \ nat \ 'a tree" where "braun_of x n = fst (braun2_of x n)" declare braun2_of.simps [simp del] lemma braun2_of_size_braun: "braun2_of x n = (s,t) \ size s = n \ size t = n+1 \ braun s \ braun t" proof(induction x n arbitrary: s t rule: braun2_of.induct) case (1 x n) then show ?case by (auto simp: braun2_of.simps[of x n] split: prod.splits if_splits) presburger+ qed lemma braun2_of_replicate: "braun2_of x n = (s,t) \ list s = replicate n x \ list t = replicate (n+1) x" proof(induction x n arbitrary: s t rule: braun2_of.induct) case (1 x n) have "x # replicate m x = replicate (m+1) x" for m by simp with 1 show ?case apply (auto simp: braun2_of.simps[of x n] replicate.simps(2)[of 0 x] simp del: replicate.simps(2) split: prod.splits if_splits) by presburger+ qed corollary braun_braun_of: "braun(braun_of x n)" unfolding braun_of_def by (metis eq_fst_iff braun2_of_size_braun) corollary list_braun_of: "list(braun_of x n) = replicate n x" unfolding braun_of_def by (metis eq_fst_iff braun2_of_replicate) subsubsection "Proof Infrastructure" text \Originally due to Thomas Sewell.\ paragraph \\take_nths\\ fun take_nths :: "nat \ nat \ 'a list \ 'a list" where "take_nths i k [] = []" | "take_nths i k (x # xs) = (if i = 0 then x # take_nths (2^k - 1) k xs else take_nths (i - 1) k xs)" text \This is the more concise definition but seems to complicate the proofs:\ lemma take_nths_eq_nths: "take_nths i k xs = nths xs (\n. {n*2^k + i})" proof(induction xs arbitrary: i) case Nil then show ?case by simp next case (Cons x xs) show ?case proof cases assume [simp]: "i = 0" have "\x n. Suc x = n * 2 ^ k \ \xa. x = Suc xa * 2 ^ k - Suc 0" by (metis diff_Suc_Suc diff_zero mult_eq_0_iff not0_implies_Suc) then have "(\n. {(n+1) * 2 ^ k - 1}) = {m. \n. Suc m = n * 2 ^ k}" by (auto simp del: mult_Suc) thus ?thesis by (simp add: Cons.IH ac_simps nths_Cons) next assume [arith]: "i \ 0" have "\x n. Suc x = n * 2 ^ k + i \ \xa. x = xa * 2 ^ k + i - Suc 0" by (metis diff_Suc_Suc diff_zero) then have "(\n. {n * 2 ^ k + i - 1}) = {m. \n. Suc m = n * 2 ^ k + i}" by auto thus ?thesis by (simp add: Cons.IH nths_Cons) qed qed lemma take_nths_drop: "take_nths i k (drop j xs) = take_nths (i + j) k xs" by (induct xs arbitrary: i j; simp add: drop_Cons split: nat.split) lemma take_nths_00: "take_nths 0 0 xs = xs" by (induct xs; simp) lemma splice_take_nths: "splice (take_nths 0 (Suc 0) xs) (take_nths (Suc 0) (Suc 0) xs) = xs" by (induct xs; simp) lemma take_nths_take_nths: "take_nths i m (take_nths j n xs) = take_nths ((i * 2^n) + j) (m + n) xs" by (induct xs arbitrary: i j; simp add: algebra_simps power_add) lemma take_nths_empty: "(take_nths i k xs = []) = (length xs \ i)" by (induction xs arbitrary: i k) auto lemma hd_take_nths: "i < length xs \ hd(take_nths i k xs) = xs ! i" by (induction xs arbitrary: i k) auto lemma take_nths_01_splice: "\ length xs = length ys \ length xs = length ys + 1 \ \ take_nths 0 (Suc 0) (splice xs ys) = xs \ take_nths (Suc 0) (Suc 0) (splice xs ys) = ys" by (induct xs arbitrary: ys; case_tac ys; simp) lemma length_take_nths_00: "length (take_nths 0 (Suc 0) xs) = length (take_nths (Suc 0) (Suc 0) xs) \ length (take_nths 0 (Suc 0) xs) = length (take_nths (Suc 0) (Suc 0) xs) + 1" by (induct xs) auto paragraph \\braun_list\\ fun braun_list :: "'a tree \ 'a list \ bool" where "braun_list Leaf xs = (xs = [])" | "braun_list (Node l x r) xs = (xs \ [] \ x = hd xs \ braun_list l (take_nths 1 1 xs) \ braun_list r (take_nths 2 1 xs))" lemma braun_list_eq: "braun_list t xs = (braun t \ xs = list t)" proof (induct t arbitrary: xs) case Leaf show ?case by simp next case Node show ?case using length_take_nths_00[of xs] splice_take_nths[of xs] by (auto simp: neq_Nil_conv Node.hyps size_list[symmetric] take_nths_01_splice) qed subsubsection \Converting a list of elements into a Braun tree\ fun nodes :: "'a tree list \ 'a list \ 'a tree list \ 'a tree list" where "nodes (l#ls) (x#xs) (r#rs) = Node l x r # nodes ls xs rs" | "nodes (l#ls) (x#xs) [] = Node l x Leaf # nodes ls xs []" | "nodes [] (x#xs) (r#rs) = Node Leaf x r # nodes [] xs rs" | "nodes [] (x#xs) [] = Node Leaf x Leaf # nodes [] xs []" | "nodes ls [] rs = []" fun brauns :: "nat \ 'a list \ 'a tree list" where "brauns k xs = (if xs = [] then [] else let ys = take (2^k) xs; zs = drop (2^k) xs; ts = brauns (k+1) zs in nodes ts ys (drop (2^k) ts))" declare brauns.simps[simp del] definition brauns1 :: "'a list \ 'a tree" where "brauns1 xs = (if xs = [] then Leaf else brauns 0 xs ! 0)" -fun T_brauns :: "nat \ 'a list \ nat" where - "T_brauns k xs = (if xs = [] then 0 else - let ys = take (2^k) xs; - zs = drop (2^k) xs; - ts = brauns (k+1) zs - in 4 * min (2^k) (length xs) + T_brauns (k+1) zs)" - paragraph "Functional correctness" text \The proof is originally due to Thomas Sewell.\ lemma length_nodes: "length (nodes ls xs rs) = length xs" by (induct ls xs rs rule: nodes.induct; simp) lemma nth_nodes: "i < length xs \ nodes ls xs rs ! i = Node (if i < length ls then ls ! i else Leaf) (xs ! i) (if i < length rs then rs ! i else Leaf)" by (induct ls xs rs arbitrary: i rule: nodes.induct; simp add: nth_Cons split: nat.split) theorem length_brauns: "length (brauns k xs) = min (length xs) (2 ^ k)" proof (induct xs arbitrary: k rule: measure_induct_rule[where f=length]) case (less xs) thus ?case by (simp add: brauns.simps[of k xs] length_nodes) qed theorem brauns_correct: "i < min (length xs) (2 ^ k) \ braun_list (brauns k xs ! i) (take_nths i k xs)" proof (induct xs arbitrary: i k rule: measure_induct_rule[where f=length]) case (less xs) have "xs \ []" using less.prems by auto let ?zs = "drop (2^k) xs" let ?ts = "brauns (Suc k) ?zs" from less.hyps[of ?zs _ "Suc k"] have IH: "\ j = i + 2 ^ k; i < min (length ?zs) (2 ^ (k+1)) \ \ braun_list (?ts ! i) (take_nths j (Suc k) xs)" for i j using \xs \ []\ by (simp add: take_nths_drop) show ?case using less.prems by (auto simp: brauns.simps[of k xs] nth_nodes take_nths_take_nths IH take_nths_empty hd_take_nths length_brauns) qed corollary brauns1_correct: "braun (brauns1 xs) \ list (brauns1 xs) = xs" using brauns_correct[of 0 xs 0] by (simp add: brauns1_def braun_list_eq take_nths_00) paragraph "Running Time Analysis" -theorem T_brauns: - "T_brauns k xs = 4 * length xs" +time_fun_0 "(^)" + +time_fun nodes + +lemma T_nodes: "T_nodes ls xs rs = length xs + 1" +by(induction ls xs rs rule: T_nodes.induct) auto + +time_fun brauns + +lemma T_brauns_pretty: "T_brauns k xs = (if xs = [] then 0 else + let ys = take (2^k) xs; + zs = drop (2^k) xs; + ts = brauns (k+1) zs + in T_take (2 ^ k) xs + T_drop (2 ^ k) xs + T_brauns (k + 1) zs + T_drop (2 ^ k) ts + T_nodes ts ys (drop (2 ^ k) ts)) + 1" +by(simp) + +lemma T_brauns_simple: "T_brauns k xs = (if xs = [] then 0 else + 3 * (min (2^k) (length xs) + 1) + (min (2^k) (length xs - 2^k) + 1) + T_brauns (k+1) (drop (2^k) xs)) + 1" +by(simp add: T_nodes T_take_eq T_drop_eq length_brauns min_def) + +theorem T_brauns_ub: + "T_brauns k xs \ 9 * (length xs + 1)" proof (induction xs arbitrary: k rule: measure_induct_rule[where f = length]) case (less xs) show ?case proof cases assume "xs = []" thus ?thesis by(simp) next assume "xs \ []" - let ?zs = "drop (2^k) xs" - have "T_brauns k xs = T_brauns (k+1) ?zs + 4 * min (2^k) (length xs)" - using \xs \ []\ by(simp) - also have "\ = 4 * length ?zs + 4 * min (2^k) (length xs)" + let ?n = "length xs" let ?zs = "drop (2^k) xs" + have *: "?n - 2^k + 1 \ ?n" + using \xs \ []\ less_eq_Suc_le by fastforce + have "T_brauns k xs = + 3 * (min (2^k) ?n + 1) + (min (2^k) (?n - 2^k) + 1) + T_brauns (k+1) ?zs + 1" + unfolding T_brauns_simple[of k xs] using \xs \ []\ by(simp del: T_brauns.simps) + also have "\ \ 4 * min (2^k) ?n + T_brauns (k+1) ?zs + 5" + by(simp add: min_def) + also have "\ \ 4 * min (2^k) ?n + 9 * (length ?zs + 1) + 5" using less[of ?zs "k+1"] \xs \ []\ - by (simp) - also have "\ = 4 * length xs" + by (simp del: T_brauns.simps) + also have "\ = 4 * min (2^k) ?n + 9 * (?n - 2^k + 1) + 5" by(simp) - finally show ?case . + also have "\ = 4 * min (2^k) ?n + 4 * (?n - 2^k) + 5 * (?n - 2^k + 1) + 9" + by(simp) + also have "\ = 4 * ?n + 5 * (?n - 2^k + 1) + 9" + by(simp) + also have "\ \ 4 * ?n + 5 * ?n + 9" + using * by(simp) + also have "\ = 9 * (?n + 1)" + by (simp add: Suc_leI) + finally show ?thesis . qed qed subsubsection \Converting a Braun Tree into a List of Elements\ text \The code and the proof are originally due to Thomas Sewell (except running time).\ function list_fast_rec :: "'a tree list \ 'a list" where "list_fast_rec ts = (let us = filter (\t. t \ Leaf) ts in if us = [] then [] else map value us @ list_fast_rec (map left us @ map right us))" by (pat_completeness, auto) lemma list_fast_rec_term1: "ts \ [] \ Leaf \ set ts \ sum_list (map (size o left) ts) + sum_list (map (size o right) ts) < sum_list (map size ts)" apply (clarsimp simp: sum_list_addf[symmetric] sum_list_map_filter') apply (rule sum_list_strict_mono; clarsimp?) apply (case_tac x; simp) done lemma list_fast_rec_term: "us \ [] \ us = filter (\t. t \ \\) ts \ sum_list (map (size o left) us) + sum_list (map (size o right) us) < sum_list (map size ts)" apply (rule order_less_le_trans, rule list_fast_rec_term1, simp_all) apply (rule sum_list_filter_le_nat) done termination by (relation "measure (sum_list o map size)"; simp add: list_fast_rec_term) declare list_fast_rec.simps[simp del] definition list_fast :: "'a tree \ 'a list" where "list_fast t = list_fast_rec [t]" -function T_list_fast_rec :: "'a tree list \ nat" where +(* TODO: map and filter are a problem! +- The automatically generated T_map is slightly more complicated than needed. +- We cannot use the manually defined T_map directly because the automatic translation + assumes that T_map has a more complicated type and generates a "wrong" call. +Therefore we hide map/filter at the moment. +*) + +definition "filter_not_Leaf = filter (\t. t \ Leaf)" + +definition "map_left = map left" +definition "map_right = map right" +definition "map_value = map value" + +definition "T_filter_not_Leaf ts = length ts + 1" +definition "T_map_left ts = length ts + 1" +definition "T_map_right ts = length ts + 1" +definition "T_map_value ts = length ts + 1" +(* +time_fun "tree.value" +time_fun "left" +time_fun "right" +*) + +lemmas defs = filter_not_Leaf_def map_left_def map_right_def map_value_def + T_filter_not_Leaf_def T_map_value_def T_map_left_def T_map_right_def + +(* A variant w/o explicit map/filter; T_list_fast_rec is generated from it *) +lemma list_fast_rec_simp: +"list_fast_rec ts = (let us = filter_not_Leaf ts in + if us = [] then [] else + map_value us @ list_fast_rec (map_left us @ map_right us))" +unfolding defs list_fast_rec.simps[of ts] by(rule refl) + +time_function list_fast_rec equations list_fast_rec_simp +termination + by (relation "measure (sum_list o map size)"; simp add: list_fast_rec_term defs) + +lemma T_list_fast_rec_pretty: "T_list_fast_rec ts = (let us = filter (\t. t \ Leaf) ts - in length ts + (if us = [] then 0 else - 5 * length us + T_list_fast_rec (map left us @ map right us)))" - by (pat_completeness, auto) - -termination - by (relation "measure (sum_list o map size)"; simp add: list_fast_rec_term) + in length ts + 1 + (if us = [] then 0 else + 5 * (length us + 1) + T_list_fast_rec (map left us @ map right us))) + 1" +unfolding defs T_list_fast_rec.simps[of ts] +by(simp add: T_append) declare T_list_fast_rec.simps[simp del] + paragraph "Functional Correctness" lemma list_fast_rec_all_Leaf: "\t \ set ts. t = Leaf \ list_fast_rec ts = []" by (simp add: filter_empty_conv list_fast_rec.simps) lemma take_nths_eq_single: "length xs - i < 2^n \ take_nths i n xs = take 1 (drop i xs)" by (induction xs arbitrary: i n; simp add: drop_Cons') lemma braun_list_Nil: "braun_list t [] = (t = Leaf)" by (cases t; simp) lemma braun_list_not_Nil: "xs \ [] \ braun_list t xs = (\l x r. t = Node l x r \ x = hd xs \ braun_list l (take_nths 1 1 xs) \ braun_list r (take_nths 2 1 xs))" by(cases t; simp) theorem list_fast_rec_correct: "\ length ts = 2 ^ k; \i < 2 ^ k. braun_list (ts ! i) (take_nths i k xs) \ \ list_fast_rec ts = xs" proof (induct xs arbitrary: k ts rule: measure_induct_rule[where f=length]) case (less xs) show ?case proof (cases "length xs < 2 ^ k") case True from less.prems True have filter: "\n. ts = map (\x. Node Leaf x Leaf) xs @ replicate n Leaf" apply (rule_tac x="length ts - length xs" in exI) apply (clarsimp simp: list_eq_iff_nth_eq) apply(auto simp: nth_append braun_list_not_Nil take_nths_eq_single braun_list_Nil hd_drop_conv_nth) done thus ?thesis by (clarsimp simp: list_fast_rec.simps[of ts] o_def list_fast_rec_all_Leaf) next case False with less.prems(2) have *: "\i < 2 ^ k. ts ! i \ Leaf \ value (ts ! i) = xs ! i \ braun_list (left (ts ! i)) (take_nths (i + 2 ^ k) (Suc k) xs) \ (\ys. ys = take_nths (i + 2 * 2 ^ k) (Suc k) xs \ braun_list (right (ts ! i)) ys)" by (auto simp: take_nths_empty hd_take_nths braun_list_not_Nil take_nths_take_nths algebra_simps) have 1: "map value ts = take (2 ^ k) xs" using less.prems(1) False by (simp add: list_eq_iff_nth_eq *) have 2: "list_fast_rec (map left ts @ map right ts) = drop (2 ^ k) xs" using less.prems(1) False by (auto intro!: Nat.diff_less less.hyps[where k= "Suc k"] simp: nth_append * take_nths_drop algebra_simps) from less.prems(1) False show ?thesis by (auto simp: list_fast_rec.simps[of ts] 1 2 * all_set_conv_all_nth) qed qed corollary list_fast_correct: "braun t \ list_fast t = list t" by (simp add: list_fast_def take_nths_00 braun_list_eq list_fast_rec_correct[where k=0]) + paragraph "Running Time Analysis" lemma sum_tree_list_children: "\t \ set ts. t \ Leaf \ (\t\ts. k * size t) = (\t \ map left ts @ map right ts. k * size t) + k * length ts" by(induction ts)(auto simp add: neq_Leaf_iff algebra_simps) theorem T_list_fast_rec_ub: - "T_list_fast_rec ts \ sum_list (map (\t. 7*size t + 1) ts)" + "T_list_fast_rec ts \ sum_list (map (\t. 14*size t + 1) ts) + 2" proof (induction ts rule: measure_induct_rule[where f="sum_list o map size"]) case (less ts) let ?us = "filter (\t. t \ Leaf) ts" show ?case proof cases assume "?us = []" thus ?thesis using T_list_fast_rec.simps[of ts] - by(simp add: sum_list_Suc) + by(simp add: defs sum_list_Suc) next assume "?us \ []" let ?children = "map left ?us @ map right ?us" - have "T_list_fast_rec ts = T_list_fast_rec ?children + 5 * length ?us + length ts" - using \?us \ []\ T_list_fast_rec.simps[of ts] by(simp) - also have "\ \ (\t\?children. 7 * size t + 1) + 5 * length ?us + length ts" + have 1: "1 \ length ?us" + using \?us \ []\ linorder_not_less by auto + have "T_list_fast_rec ts = T_list_fast_rec ?children + 5 * length ?us + length ts + 7" + using \?us \ []\ T_list_fast_rec.simps[of ts] by(simp add: defs T_append) + also have "\ \ (\t\?children. 14 * size t + 1) + 5 * length ?us + length ts + 9" using less[of "?children"] list_fast_rec_term[of "?us"] \?us \ []\ by (simp) - also have "\ = (\t\?children. 7*size t) + 7 * length ?us + length ts" + also have "\ = (\t\?children. 14 * size t) + 7 * length ?us + length ts + 9" by(simp add: sum_list_Suc o_def) - also have "\ = (\t\?us. 7*size t) + length ts" + also have "\ \ (\t\?children. 14 * size t) + 14 * length ?us + length ts + 2" + using 1 by(simp add: sum_list_Suc o_def) + also have "\ = (\t\?us. 14 * size t) + length ts + 2" by(simp add: sum_tree_list_children) - also have "\ \ (\t\ts. 7*size t) + length ts" + also have "\ \ (\t\ts. 14 * size t) + length ts + 2" by(simp add: sum_list_filter_le_nat) - also have "\ = (\t\ts. 7 * size t + 1)" + also have "\ = (\t\ts. 14 * size t + 1) + 2" by(simp add: sum_list_Suc) finally show ?case . qed qed end diff --git a/src/HOL/Data_Structures/Define_Time_Function.ML b/src/HOL/Data_Structures/Define_Time_Function.ML --- a/src/HOL/Data_Structures/Define_Time_Function.ML +++ b/src/HOL/Data_Structures/Define_Time_Function.ML @@ -1,580 +1,582 @@ signature TIMING_FUNCTIONS = sig type 'a converter = { constc : local_theory -> term list -> (term -> 'a) -> term -> 'a, funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a, casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a }; val walk : local_theory -> term list -> 'a converter -> term -> 'a type pfunc = { names : string list, terms : term list, typs : typ list } val fun_pretty': Proof.context -> pfunc -> Pretty.T val fun_pretty: Proof.context -> Function.info -> Pretty.T val print_timing': Proof.context -> pfunc -> pfunc -> unit val print_timing: Proof.context -> Function.info -> Function.info -> unit val reg_and_proove_time_func: theory -> term list -> term list -> bool -> Function.info * theory val reg_time_func: theory -> term list -> term list -> bool -> theory val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic end structure Timing_Functions : TIMING_FUNCTIONS = struct (* Configure config variable to adjust the prefix *) val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_") (* some default values to build terms easier *) val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT) val one = Const (@{const_name "Groups.one"}, HOLogic.natT) (* Extracts terms from function info *) fun terms_of_info (info: Function.info) = let val {simps, ...} = info in map Thm.prop_of (case simps of SOME s => s | NONE => error "No terms of function found in info") end; type pfunc = { names : string list, terms : term list, typs : typ list } fun info_pfunc (info: Function.info): pfunc = let val {defname, fs, ...} = info; val T = case hd fs of (Const (_,T)) => T | _ => error "Internal error: Invalid info to print" in { names=[Binding.name_of defname], terms=terms_of_info info, typs=[T] } end (* Auxiliary functions for printing functions *) fun fun_pretty' ctxt (pfunc: pfunc) = let val {names, terms, typs} = pfunc; val header_beg = Pretty.str "fun "; fun prepHeadCont (nm,T) = [Pretty.str (nm ^ " :: "), (Pretty.quote (Syntax.pretty_typ ctxt T))] val header_content = List.concat (prepHeadCont (hd names,hd typs) :: map ((fn l => Pretty.str "\nand " :: l) o prepHeadCont) (ListPair.zip (tl names, tl typs))); val header_end = Pretty.str " where\n "; val header = [header_beg] @ header_content @ [header_end]; fun separate sep prts = flat (Library.separate [Pretty.str sep] (map single prts)); val ptrms = (separate "\n| " (map (Syntax.pretty_term ctxt) terms)); in Pretty.text_fold (header @ ptrms) end fun fun_pretty ctxt = fun_pretty' ctxt o info_pfunc fun print_timing' ctxt (opfunc: pfunc) (tpfunc: pfunc) = let val {names, ...} = opfunc; val poriginal = Pretty.item [Pretty.str "Original function:\n", fun_pretty' ctxt opfunc] val ptiming = Pretty.item [Pretty.str ("Running time function:\n"), fun_pretty' ctxt tpfunc] in Pretty.writeln (Pretty.text_fold [Pretty.str ("Converting " ^ (hd names) ^ (String.concat (map (fn nm => ", " ^ nm) (tl names))) ^ "\n"), poriginal, Pretty.str "\n", ptiming]) end fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) = print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo) val If_name = @{const_name "HOL.If"} val Let_name = @{const_name "HOL.Let"} (* returns true if it's an if term *) fun is_if (Const (n,_)) = (n = If_name) | is_if _ = false (* returns true if it's a case term *) fun is_case (Const (n,_)) = String.isPrefix "case_" (List.last (String.fields (fn s => s = #".") n)) | is_case _ = false (* returns true if it's a let term *) fun is_let (Const (n,_)) = (n = Let_name) | is_let _ = false (* change type of original function to new type (_ \ ... \ _ to _ \ ... \ nat) and replace all function arguments f with (t*T_f) *) fun change_typ (Type ("fun", [T1, T2])) = Type ("fun", [check_for_fun T1, change_typ T2]) | change_typ _ = HOLogic.natT and check_for_fun (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ f) | check_for_fun (Type ("Product_Type.prod", [t1,t2])) = HOLogic.mk_prodT (check_for_fun t1, check_for_fun t2) | check_for_fun f = f (* Convert string name of function to its timing equivalent *) fun fun_name_to_time ctxt name = let val prefix = Config.get ctxt bprefix fun replace_last_name [n] = [prefix ^ n] | replace_last_name (n::ns) = n :: (replace_last_name ns) | replace_last_name _ = error "Internal error: Invalid function name to convert" val parts = String.fields (fn s => s = #".") name in String.concatWith "." (replace_last_name parts) end (* Count number of arguments of a function *) fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0) | count_args _ = 0 (* Check if number of arguments matches function *) fun check_args s (Const (_,T), args) = (if length args = count_args T then () else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) | check_args s (Free (_,T), args) = (if length args = count_args T then () else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) | check_args s _ = error ("Partial applications/Lambdas not allowed (" ^ s ^ ")") (* Removes Abs *) fun rem_abs f (Abs (_,_,t)) = rem_abs f t | rem_abs f t = f t (* Map right side of equation *) fun map_r f (pT $ (eq $ l $ r)) = (pT $ (eq $ l $ f r)) | map_r _ _ = error "Internal error: No right side of equation found" (* Get left side of equation *) fun get_l (_ $ (_ $ l $ _)) = l | get_l _ = error "Internal error: No left side of equation found" (* Get right side of equation *) fun get_r (_ $ (_ $ _ $ r)) = r | get_r _ = error "Internal error: No right side of equation found" (* Return name of Const *) fun Const_name (Const (nm,_)) = SOME nm | Const_name _ = NONE fun time_term ctxt (Const (nm,T)) = let val T_nm = fun_name_to_time ctxt nm val T_T = change_typ T in (SOME (Syntax.check_term ctxt (Const (T_nm,T_T)))) handle (ERROR _) => case Syntax.read_term ctxt (Long_Name.base_name T_nm) of (Const (nm,_)) => SOME (Const (nm,T_T)) | _ => error ("Timing function of " ^ nm ^ " is not defined") end | time_term _ _ = error "Internal error: No valid function given" type 'a converter = { constc : local_theory -> term list -> (term -> 'a) -> term -> 'a, funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a, casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a }; (* Walks over term and calls given converter *) fun walk_func (t1 $ t2) ts = walk_func t1 (t2::ts) | walk_func t ts = (t, ts) fun build_func (f, []) = f | build_func (f, (t::ts)) = build_func (f$t, ts) fun walk_abs (Abs (nm,T,t)) nms Ts = walk_abs t (nm::nms) (T::Ts) | walk_abs t nms Ts = (t, nms, Ts) fun build_abs t (nm::nms) (T::Ts) = build_abs (Abs (nm,T,t)) nms Ts | build_abs t [] [] = t | build_abs _ _ _ = error "Internal error: Invalid terms to build abs" fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) = let val (f, args) = walk_func t [] val this = (walk ctxt origin conv) val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ()) in (if is_if f then (case f of (Const (_,T)) => (case args of [cond, t, f] => ifc ctxt origin this T cond t f | _ => error "Partial applications not supported (if)") | _ => error "Internal error: invalid if term") else if is_case f then casec ctxt origin this f args else if is_let f then (case f of (Const (_,lT)) => (case args of [exp, t] => let val (t,nms,Ts) = walk_abs t [] [] in letc ctxt origin this lT exp nms Ts t end | _ => error "Partial applications not allowed (let)") | _ => error "Internal error: invalid let term") else funcc ctxt origin this f args) end | walk ctxt origin (conv as {constc, ...}) c = constc ctxt origin (walk ctxt origin conv) c (* 1. Fix all terms *) (* Exchange Var in types and terms to Free *) fun fixTerms (Var(ixn,T)) = Free (fst ixn, T) | fixTerms t = t fun fixTypes (TVar ((t, _), T)) = TFree (t, T) | fixTypes t = t fun noFun (Type ("fun",_)) = error "Functions in datatypes are not allowed in case constructions" | noFun _ = () fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t fun casecAbs ctxt f n (Type (_,[T,Tr])) (Abs (v,Ta,t)) = (noFun T; Abs (v,Ta,casecAbs ctxt f n Tr t)) | casecAbs ctxt f n (Type (Tn,[T,Tr])) t = (noFun T; case Variable.variant_fixes ["x"] ctxt of ([v],ctxt) => (if Tn = "fun" then Abs (v,T,casecAbs ctxt f (n + 1) Tr t) else f t) | _ => error "Internal error: could not fix variable") | casecAbs _ f n _ t = f (casecBuildBounds n (Term.incr_bv n 0 t)) fun fixCasecCases _ _ _ [t] = [t] | fixCasecCases ctxt f (Type (_,[T,Tr])) (t::ts) = casecAbs ctxt f 0 T t :: fixCasecCases ctxt f Tr ts | fixCasecCases _ _ _ _ = error "Internal error: invalid case types/terms" fun fixCasec ctxt _ f (t as Const (_,T)) args = (check_args "cases" (t,args); build_func (t,fixCasecCases ctxt f T args)) | fixCasec _ _ _ _ _ = error "Internal error: invalid case term" fun fixPartTerms ctxt (term: term list) t = let val _ = check_args "args" (walk_func (get_l t) []) in map_r (walk ctxt term { funcc = (fn _ => fn _ => fn f => fn t => fn args => (check_args "func" (t,args); build_func (t, map f args))), constc = (fn _ => fn _ => fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)), ifc = (fn _ => fn _ => fn f => fn T => fn cond => fn tt => fn tf => ((Const (If_name, T)) $ f cond $ (f tt) $ (f tf))), casec = fixCasec, letc = (fn _ => fn _ => fn f => fn expT => fn exp => fn nms => fn Ts => fn t => let val f' = if length nms = 0 then (case f (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)") else f t in (Const (Let_name,expT) $ (f exp) $ build_abs f' nms Ts) end) }) t end (* 2. Check if function is recursive *) fun or f (a,b) = f a orelse b fun find_rec ctxt term = (walk ctxt term { funcc = (fn _ => fn _ => fn f => fn t => fn args => List.exists (fn term => Const_name t = Const_name term) term orelse List.foldr (or f) false args), constc = (K o K o K o K) false, ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond orelse f tt orelse f tf), casec = (fn _ => fn _ => fn f => fn t => fn cs => f t orelse List.foldr (or (rem_abs f)) false cs), letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp orelse f t) }) o get_r fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false (* 3. Convert equations *) (* Some Helper *) val plusTyp = @{typ "nat => nat => nat"} fun plus (SOME a) (SOME b) = SOME (Const (@{const_name "Groups.plus"}, plusTyp) $ a $ b) | plus (SOME a) NONE = SOME a | plus NONE (SOME b) = SOME b | plus NONE NONE = NONE fun opt_term NONE = HOLogic.zero | opt_term (SOME t) = t (* Converting of function term *) fun fun_to_time ctxt (origin: term list) (func as Const (nm,T)) = let val full_name_origin = map (fst o dest_Const) origin val prefix = Config.get ctxt bprefix in if List.exists (fn nm_orig => nm = nm_orig) full_name_origin then SOME (Free (prefix ^ Term.term_name func, change_typ T)) else if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else time_term ctxt func end | fun_to_time _ _ (Free (nm,T)) = SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))) | fun_to_time _ _ _ = error "Internal error: invalid function to convert" (* Convert arguments of left side of a term *) fun conv_arg _ _ (Free (nm,T as Type("fun",_))) = Free (nm, HOLogic.mk_prodT (T, change_typ T)) | conv_arg ctxt origin (f as Const (_,T as Type("fun",_))) = (Const (@{const_name "Product_Type.Pair"}, Type ("fun", [T,Type ("fun",[change_typ T, HOLogic.mk_prodT (T,change_typ T)])])) $ f $ (fun_to_time ctxt origin f |> Option.valOf)) | conv_arg ctxt origin ((Const ("Product_Type.Pair", _)) $ l $ r) = HOLogic.mk_prod (conv_arg ctxt origin l, conv_arg ctxt origin r) | conv_arg _ _ x = x fun conv_args ctxt origin = map (conv_arg ctxt origin) (* Handle function calls *) fun build_zero (Type ("fun", [T, R])) = Abs ("x", T, build_zero R) | build_zero _ = zero fun funcc_use_origin _ _ (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) | funcc_use_origin _ _ t = t fun funcc_conv_arg ctxt origin (t as (_ $ _)) = map_aterms (funcc_use_origin ctxt origin) t | funcc_conv_arg _ _ (Free (nm, T as Type ("fun",_))) = (Free (nm, HOLogic.mk_prodT (T, change_typ T))) | funcc_conv_arg ctxt origin (f as Const (_,T as Type ("fun",_))) = (Const (@{const_name "Product_Type.Pair"}, Type ("fun", [T,Type ("fun",[change_typ T, HOLogic.mk_prodT (T,change_typ T)])])) $ f $ (Option.getOpt (fun_to_time ctxt origin f, build_zero T))) | funcc_conv_arg _ _ t = t fun funcc_conv_args ctxt origin = map (funcc_conv_arg ctxt origin) fun funcc ctxt (origin: term list) f func args = List.foldr (I #-> plus) (case fun_to_time ctxt origin func of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin args)) | NONE => NONE) (map f args) (* Handle case terms *) fun casecIsCase (Type (n1, [_,Type (n2, _)])) = (n1 = "fun" andalso n2 = "fun") | casecIsCase _ = false fun casecLastTyp (Type (n, [T1,T2])) = Type (n, [T1, change_typ T2]) | casecLastTyp _ = error "Internal error: invalid case type" fun casecTyp (Type (n, [T1, T2])) = Type (n, [change_typ T1, (if casecIsCase T2 then casecTyp else casecLastTyp) T2]) | casecTyp _ = error "Internal error: invalid case type" fun casecAbs f (Abs (v,Ta,t)) = (case casecAbs f t of (nconst,t) => (nconst,Abs (v,Ta,t))) | casecAbs f t = (case f t of NONE => (false,HOLogic.zero) | SOME t => (true,t)) fun casecArgs _ [t] = (false, [t]) | casecArgs f (t::ar) = (case casecAbs f t of (nconst, tt) => casecArgs f ar ||> (fn ar => tt :: ar) |>> (if nconst then K true else I)) | casecArgs _ _ = error "Internal error: invalid case term" fun casec _ _ f (Const (t,T)) args = if not (casecIsCase T) then error "Internal error: invalid case type" else let val (nconst, args') = casecArgs f args in plus (f (List.last args)) (if nconst then SOME (build_func (Const (t,casecTyp T), args')) else NONE) end | casec _ _ _ _ _ = error "Internal error: invalid case term" (* Handle if terms -> drop the term if true and false terms are zero *) fun ifc ctxt origin f _ cond tt ft = let fun use_origin _ _ (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) | use_origin _ _ t = t val rcond = map_aterms (use_origin ctxt origin) cond val tt = f tt val ft = f ft in plus (f cond) (case (tt,ft) of (NONE, NONE) => NONE | _ => (SOME ((Const (If_name, @{typ "bool \ nat \ nat \ nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) end fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])])) | letc_change_typ _ = error "Internal error: invalid let type" fun letc _ _ f expT exp nms Ts t = plus (f exp) (if length nms = 0 (* In case of "length nms = 0" the expression got reducted Here we need Bound 0 to gain non-partial application *) then (case f (t $ Bound 0) of SOME (t' $ Bound 0) => (SOME (Const (Let_name, letc_change_typ expT) $ exp $ t')) (* Expression is not used and can therefore let be dropped *) | SOME t' => SOME t' | NONE => NONE) else (case f t of SOME t' => SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ exp $ build_abs t' nms Ts else Term.subst_bounds([exp],t')) | NONE => NONE)) (* The converter for timing functions given to the walker *) val converter : term option converter = { - constc = fn _ => fn _ => fn _ => fn _ => NONE, + constc = fn _ => fn _ => fn _ => fn t => + (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"})) + | _ => NONE), funcc = funcc, ifc = ifc, casec = casec, letc = letc } fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE)) (* Use converter to convert right side of a term *) fun to_time ctxt origin is_rec term = top_converter is_rec ctxt origin (walk ctxt origin converter term) (* Converts a term to its running time version *) fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) = pT $ (Const (eqN, @{typ "nat \ nat \ bool"}) $ (build_func ((walk_func l []) |>> (fun_to_time ctxt origin) |>> Option.valOf ||> conv_args ctxt origin)) $ (to_time ctxt origin is_rec r)) | convert_term _ _ _ _ = error "Internal error: invalid term to convert" (* 4. Tactic to prove "f_dom n" *) fun time_dom_tac ctxt induct_rule domintros = (Induction.induction_tac ctxt true [] [[]] [] (SOME [induct_rule]) [] THEN_ALL_NEW ((K (auto_tac ctxt)) THEN' (fn i => FIRST' ( (if i <= length domintros then [Metis_Tactic.metis_tac [] ATP_Problem_Generate.combsN ctxt [List.nth (domintros, i-1)]] else []) @ [Metis_Tactic.metis_tac [] ATP_Problem_Generate.combsN ctxt domintros]) i))) fun get_terms theory (term: term) = Spec_Rules.retrieve_global theory term |> hd |> #rules |> map Thm.prop_of handle Empty => error "Function or terms of function not found" (* Register timing function of a given function *) fun reg_and_proove_time_func (theory: theory) (term: term list) (terms: term list) print = reg_time_func theory term terms false |> proove_termination term terms print and reg_time_func (theory: theory) (term: term list) (terms: term list) print = let val lthy = Named_Target.theory_init theory val _ = case time_term lthy (hd term) handle (ERROR _) => NONE of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term))) | NONE => () (* 1. Fix all terms *) (* Exchange Var in types and terms to Free and check constraints *) val terms = map (map_aterms fixTerms #> map_types (map_atyps fixTypes) #> fixPartTerms lthy term) terms (* 2. Check if function is recursive *) val is_rec = is_rec lthy term terms (* 3. Convert every equation - Change type of toplevel equation from _ \ _ \ bool to nat \ nat \ bool - On left side change name of function to timing function - Convert right side of equation with conversion schema *) val timing_terms = map (convert_term lthy term is_rec) terms (* 4. Register function and prove termination *) val names = map Term.term_name term val timing_names = map (fun_name_to_time lthy) names val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names fun pat_completeness_auto ctxt = Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) timing_terms (* For partial functions sequential=true is needed in order to support them We need sequential=false to support the automatic proof of termination over dom *) fun register seq = let val _ = (if seq then warning "Falling back on sequential function..." else ()) val fun_config = Function_Common.FunctionConfig {sequential=seq, default=NONE, domintros=true, partials=false} in Function.add_function bindings specs fun_config pat_completeness_auto lthy end (* Context for printing without showing question marks *) val print_ctxt = lthy |> Config.put show_question_marks false |> Config.put show_sorts false (* Change it for debugging *) val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term@timing_terms) (* Print result if print *) val _ = if not print then () else let val nms = map (fst o dest_Const) term val typs = map (snd o dest_Const) term in print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=timing_terms, typs=map change_typ typs } end (* Register function *) val (_, lthy) = register false handle (ERROR _) => register true | Match => register true in Local_Theory.exit_global lthy end and proove_termination (term: term list) terms print (theory: theory) = let val lthy = Named_Target.theory_init theory (* Start proving the termination *) val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE val timing_names = map (fun_name_to_time lthy o Term.term_name) term (* Proof by lexicographic_order_tac *) val (time_info, lthy') = (Function.prove_termination NONE (Lexicographic_Order.lexicographic_order_tac false lthy) lthy) handle (ERROR _) => let val _ = warning "Falling back on proof over dom..." val _ = (if length term > 1 then error "Proof over dom not supported for mutual recursive functions" else ()) fun args (a$(Var ((nm,_),T))) = args a |> (fn ar => (nm,T)::ar) | args (a$(Const (_,T))) = args a |> (fn ar => ("x",T)::ar) | args _ = [] val dom_args = terms |> hd |> get_l |> args |> Variable.variant_frees lthy [] |> map fst val {inducts, ...} = case infos of SOME [i] => i | _ => error "Proof over dom failed as no induct rule was found" val induct = (Option.valOf inducts |> hd) val domintros = Proof_Context.get_fact lthy (Facts.named (hd timing_names ^ ".domintros")) val prop = (hd timing_names ^ "_dom (" ^ (String.concatWith "," dom_args) ^ ")") |> Syntax.read_prop lthy (* Prove a helper lemma *) val dom_lemma = Goal.prove lthy dom_args [] prop (fn {context, ...} => HEADGOAL (time_dom_tac context induct domintros)) (* Add dom_lemma to simplification set *) val simp_lthy = Simplifier.add_simp dom_lemma lthy in (* Use lemma to prove termination *) Function.prove_termination NONE (auto_tac simp_lthy) lthy end (* Context for printing without showing question marks *) val print_ctxt = lthy' |> Config.put show_question_marks false |> Config.put show_sorts false (* Change it for debugging *) (* Print result if print *) val _ = if not print then () else let val nms = map (fst o dest_Const) term val typs = map (snd o dest_Const) term in print_timing' print_ctxt { names=nms, terms=terms, typs=typs } (info_pfunc time_info) end in (time_info, Local_Theory.exit_global lthy') end fun fix_definition (Const ("Pure.eq", _) $ l $ r) = Const ("HOL.Trueprop", @{typ "bool \ prop"}) $ (Const ("HOL.eq", @{typ "bool \ bool \ bool"}) $ l $ r) | fix_definition t = t fun check_definition [t] = [t] | check_definition _ = error "Only a single defnition is allowed" (* Convert function into its timing function (called by command) *) fun reg_time_fun_cmd (funcs, thms) (theory: theory) = let val ctxt = Proof_Context.init_global theory val fterms = map (Syntax.read_term ctxt) funcs val (_, lthy') = reg_and_proove_time_func theory fterms (case thms of NONE => get_terms theory (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) true in lthy' end (* Convert function into its timing function (called by command) with termination proof provided by user*) fun reg_time_function_cmd (funcs, thms) (theory: theory) = let val ctxt = Proof_Context.init_global theory val fterms = map (Syntax.read_term ctxt) funcs val theory = reg_time_func theory fterms (case thms of NONE => get_terms theory (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) true in theory end (* Convert function into its timing function (called by command) *) fun reg_time_definition_cmd (funcs, thms) (theory: theory) = let val ctxt = Proof_Context.init_global theory val fterms = map (Syntax.read_term ctxt) funcs val (_, lthy') = reg_and_proove_time_func theory fterms (case thms of NONE => get_terms theory (hd fterms) |> check_definition |> map fix_definition | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) true in lthy' end val parser = (Scan.repeat1 Parse.prop) -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd)) val _ = Outer_Syntax.command @{command_keyword "time_fun"} "Defines runtime function of a function" (parser >> (fn p => Toplevel.theory (reg_time_fun_cmd p))) val _ = Outer_Syntax.command @{command_keyword "time_function"} "Defines runtime function of a function" (parser >> (fn p => Toplevel.theory (reg_time_function_cmd p))) val _ = Outer_Syntax.command @{command_keyword "time_definition"} "Defines runtime function of a definition" (parser >> (fn p => Toplevel.theory (reg_time_definition_cmd p))) end diff --git a/src/HOL/Data_Structures/Time_Funs.thy b/src/HOL/Data_Structures/Time_Funs.thy --- a/src/HOL/Data_Structures/Time_Funs.thy +++ b/src/HOL/Data_Structures/Time_Funs.thy @@ -1,60 +1,65 @@ (* File: Data_Structures/Time_Functions.thy Author: Manuel Eberl, TU München *) section \Time functions for various standard library operations\ theory Time_Funs imports Define_Time_Function begin +time_fun "(@)" + +lemma T_append: "T_append xs ys = length xs + 1" +by(induction xs) auto + text \Automatic definition of \T_length\ is cumbersome because of the type class for \size\.\ fun T_length :: "'a list \ nat" where "T_length [] = 1" | "T_length (x # xs) = T_length xs + 1" lemma T_length_eq: "T_length xs = length xs + 1" by (induction xs) auto lemmas [simp del] = T_length.simps fun T_map :: "('a \ nat) \ 'a list \ nat" where "T_map T_f [] = 1" | "T_map T_f (x # xs) = T_f x + T_map T_f xs + 1" lemma T_map_eq: "T_map T_f xs = (\x\xs. T_f x) + length xs + 1" by (induction xs) auto lemmas [simp del] = T_map.simps fun T_filter :: "('a \ nat) \ 'a list \ nat" where "T_filter T_p [] = 1" | "T_filter T_p (x # xs) = T_p x + T_filter T_p xs + 1" lemma T_filter_eq: "T_filter T_p xs = (\x\xs. T_p x) + length xs + 1" by (induction xs) auto lemmas [simp del] = T_filter.simps fun T_nth :: "'a list \ nat \ nat" where "T_nth [] n = 1" | "T_nth (x # xs) n = (case n of 0 \ 1 | Suc n' \ T_nth xs n' + 1)" lemma T_nth_eq: "T_nth xs n = min n (length xs) + 1" by (induction xs n rule: T_nth.induct) (auto split: nat.splits) lemmas [simp del] = T_nth.simps time_fun take time_fun drop lemma T_take_eq: "T_take n xs = min n (length xs) + 1" by (induction xs arbitrary: n) (auto split: nat.splits) lemma T_drop_eq: "T_drop n xs = min n (length xs) + 1" by (induction xs arbitrary: n) (auto split: nat.splits) end