diff --git a/thys/Monad_Memo_DP/example/Bellman_Ford.thy b/thys/Monad_Memo_DP/example/Bellman_Ford.thy --- a/thys/Monad_Memo_DP/example/Bellman_Ford.thy +++ b/thys/Monad_Memo_DP/example/Bellman_Ford.thy @@ -1,1138 +1,1137 @@ subsection \The Bellman-Ford Algorithm\ theory Bellman_Ford imports "HOL-Library.IArray" "HOL-Library.Code_Target_Numeral" "HOL-Library.Product_Lexorder" "HOL-Library.RBT_Mapping" "../heap_monad/Heap_Main" Example_Misc "../util/Tracing" "../util/Ground_Function" begin subsubsection \Misc\ lemma nat_le_cases: fixes n :: nat assumes "i \ n" obtains "i < n" | "i = n" using assms by (cases "i = n") auto context dp_consistency_iterator begin lemma crel_vs_iterate_state: "crel_vs (=) () (iter_state f x)" if "((=) ===>\<^sub>T R) g f" by (metis crel_vs_iterate_state iter_state_iterate_state that) lemma consistent_crel_vs_iterate_state: "crel_vs (=) () (iter_state f x)" if "consistentDP f" using consistentDP_def crel_vs_iterate_state that by simp end instance extended :: (countable) countable proof standard obtain to_nat :: "'a \ nat" where "inj to_nat" by auto let ?f = "\ x. case x of Fin n \ to_nat n + 2 | Pinf \ 0 | Minf \ 1" from \inj _ \ have "inj ?f" by (auto simp: inj_def split: extended.split) then show "\to_nat :: 'a extended \ nat. inj to_nat" by auto qed instance extended :: (heap) heap .. instantiation "extended" :: (conditionally_complete_lattice) complete_lattice begin definition "Inf A = ( if A = {} \ A = {\} then \ else if -\ \ A \ \ bdd_below (Fin -` A) then -\ else Fin (Inf (Fin -` A)))" definition "Sup A = ( if A = {} \ A = {-\} then -\ else if \ \ A \ \ bdd_above (Fin -` A) then \ else Fin (Sup (Fin -` A)))" instance proof standard have [dest]: "Inf (Fin -` A) \ x" if "Fin x \ A" "bdd_below (Fin -` A)" for A and x :: 'a using that by (intro cInf_lower) auto have *: False if "\ z \ Inf (Fin -` A)" "\x. x \ A \ Fin z \ x" "Fin x \ A" for A and x z :: 'a using cInf_greatest[of "Fin -` A" z] that vimage_eq by force show "Inf A \ x" if "x \ A" for x :: "'a extended" and A using that unfolding Inf_extended_def by (cases x) auto show "z \ Inf A" if "\x. x \ A \ z \ x" for z :: "'a extended" and A using that unfolding Inf_extended_def apply (clarsimp; safe) apply force apply force subgoal by (cases z; force simp: bdd_below_def) subgoal by (cases z; force simp: bdd_below_def) subgoal for x y by (cases z; cases y) (auto elim: *) subgoal for x y by (cases z; cases y; simp; metis * less_eq_extended.elims(2)) done have [dest]: "x \ Sup (Fin -` A)" if "Fin x \ A" "bdd_above (Fin -` A)" for A and x :: 'a using that by (intro cSup_upper) auto have *: False if "\ Sup (Fin -` A) \ z" "\x. x \ A \ x \ Fin z" "Fin x \ A" for A and x z :: 'a using cSup_least[of "Fin -` A" z] that vimage_eq by force show "x \ Sup A" if "x \ A" for x :: "'a extended" and A using that unfolding Sup_extended_def by (cases x) auto show "Sup A \ z" if "\x. x \ A \ x \ z" for z :: "'a extended" and A using that unfolding Sup_extended_def apply (clarsimp; safe) apply force apply force subgoal by (cases z; force) subgoal by (cases z; force) subgoal for x y by (cases z; cases y) (auto elim: *) subgoal for x y by (cases z; cases y; simp; metis * extended.exhaust) done show "Inf {} = (top::'a extended)" unfolding Inf_extended_def top_extended_def by simp show "Sup {} = (bot::'a extended)" unfolding Sup_extended_def bot_extended_def by simp qed end instance "extended" :: ("{conditionally_complete_lattice,linorder}") complete_linorder .. lemma Minf_eq_zero[simp]: "-\ = 0 \ False" and Pinf_eq_zero[simp]: "\ = 0 \ False" unfolding zero_extended_def by auto lemma Sup_int: fixes x :: int and X :: "int set" assumes "X \ {}" "bdd_above X" shows "Sup X \ X \ (\y\X. y \ Sup X)" proof - from assms obtain x y where "X \ {..y}" "x \ X" by (auto simp: bdd_above_def) then have *: "finite (X \ {x..y})" "X \ {x..y} \ {}" and "x \ y" by (auto simp: subset_eq) have "\!x\X. (\y\X. y \ x)" proof { fix z assume "z \ X" have "z \ Max (X \ {x..y})" proof cases assume "x \ z" with \z \ X\ \X \ {..y}\ *(1) show ?thesis by (auto intro!: Max_ge) next assume "\ x \ z" then have "z < x" by simp also have "x \ Max (X \ {x..y})" using \x \ X\ *(1) \x \ y\ by (intro Max_ge) auto finally show ?thesis by simp qed } note le = this with Max_in[OF *] show ex: "Max (X \ {x..y}) \ X \ (\z\X. z \ Max (X \ {x..y}))" by auto fix z assume *: "z \ X \ (\y\X. y \ z)" with le have "z \ Max (X \ {x..y})" by auto moreover have "Max (X \ {x..y}) \ z" using * ex by auto ultimately show "z = Max (X \ {x..y})" by auto qed then show "Sup X \ X \ (\y\X. y \ Sup X)" unfolding Sup_int_def by (rule theI') qed lemmas Sup_int_in = Sup_int[THEN conjunct1] lemma Inf_int_in: fixes S :: "int set" assumes "S \ {}" "bdd_below S" shows "Inf S \ S" using assms unfolding Inf_int_def by (smt Sup_int_in bdd_above_uminus image_iff image_is_empty) lemma finite_setcompr_eq_image: "finite {f x |x. P x} \ finite (f ` {x. P x})" by (simp add: setcompr_eq_image) lemma finite_lists_length_le1: "finite {xs. length xs \ i \ set xs \ {0..(n::nat)}}" for i by (auto intro: finite_subset[OF _ finite_lists_length_le[OF finite_atLeastAtMost]]) lemma finite_lists_length_le2: "finite {xs. length xs + 1 \ i \ set xs \ {0..(n::nat)}}" for i by (auto intro: finite_subset[OF _ finite_lists_length_le1[of "i"]]) lemmas [simp] = finite_setcompr_eq_image finite_lists_length_le2[simplified] finite_lists_length_le1 lemma get_return: "run_state (State_Monad.bind State_Monad.get (\ m. State_Monad.return (f m))) m = (f m, m)" by (simp add: State_Monad.bind_def State_Monad.get_def) lemma list_pidgeonhole: assumes "set xs \ S" "card S < length xs" "finite S" obtains as a bs cs where "xs = as @ a # bs @ a # cs" proof - from assms have "\ distinct xs" by (metis card_mono distinct_card not_le) then show ?thesis by (metis append.assoc append_Cons not_distinct_conv_prefix split_list that) qed lemma path_eq_cycleE: assumes "v # ys @ [t] = as @ a # bs @ a # cs" obtains (Nil_Nil) "as = []" "cs = []" "v = a" "a = t" "ys = bs" | (Nil_Cons) cs' where "as = []" "v = a" "ys = bs @ a # cs'" "cs = cs' @ [t]" | (Cons_Nil) as' where "as = v # as'" "cs = []" "a = t" "ys = as' @ a # bs" | (Cons_Cons) as' cs' where "as = v # as'" "cs = cs' @ [t]" "ys = as' @ a # bs @ a # cs'" using assms by (auto simp: Cons_eq_append_conv append_eq_Cons_conv append_eq_append_conv2) lemma le_add_same_cancel1: "a + b \ a \ b \ 0" if "a < \" "-\ < a" for a b :: "int extended" using that by (cases a; cases b) (auto simp add: zero_extended_def) lemma add_gt_minfI: assumes "-\ < a" "-\ < b" shows "-\ < a + b" using assms by (cases a; cases b) auto lemma add_lt_infI: assumes "a < \" "b < \" shows "a + b < \" using assms by (cases a; cases b) auto lemma sum_list_not_infI: "sum_list xs < \" if "\ x \ set xs. x < \" for xs :: "int extended list" using that apply (induction xs) apply (simp add: zero_extended_def)+ by (smt less_extended_simps(2) plus_extended.elims) lemma sum_list_not_minfI: "sum_list xs > -\" if "\ x \ set xs. x > -\" for xs :: "int extended list" using that by (induction xs) (auto intro: add_gt_minfI simp: zero_extended_def) subsubsection \Single-Sink Shortest Path Problem\ datatype bf_result = Path "nat list" int | No_Path | Computation_Error context fixes n :: nat and W :: "nat \ nat \ int extended" begin context fixes t :: nat \ \Final node\ begin text \ The correctness proof closely follows Kleinberg \&\ Tardos: "Algorithm Design", chapter "Dynamic Programming" @{cite "Kleinberg-Tardos"} \ fun weight :: "nat list \ int extended" where "weight [v] = 0" | "weight (v # w # xs) = W v w + weight (w # xs)" definition "OPT i v = ( Min ( {weight (v # xs @ [t]) | xs. length xs + 1 \ i \ set xs \ {0..n}} \ {if t = v then 0 else \} ) )" lemma weight_alt_def': "weight (s # xs) + w = snd (fold (\j (i, x). (j, W i j + x)) xs (s, w))" by (induction xs arbitrary: s w; simp; smt add.commute add.left_commute) lemma weight_alt_def: "weight (s # xs) = snd (fold (\j (i, x). (j, W i j + x)) xs (s, 0))" by (rule weight_alt_def'[of s xs 0, simplified]) lemma weight_append: "weight (xs @ a # ys) = weight (xs @ [a]) + weight (a # ys)" by (induction xs rule: weight.induct; simp add: add.assoc) lemma OPT_0: "OPT 0 v = (if t = v then 0 else \)" unfolding OPT_def by simp subsubsection \Functional Correctness\ lemma OPT_cases: obtains (path) xs where "OPT i v = weight (v # xs @ [t])" "length xs + 1 \ i" "set xs \ {0..n}" | (sink) "v = t" "OPT i v = 0" | (unreachable) "v \ t" "OPT i v = \" unfolding OPT_def using Min_in[of "{weight (v # xs @ [t]) |xs. length xs + 1 \ i \ set xs \ {0..n}} \ {if t = v then 0 else \}"] by (auto simp: finite_lists_length_le2[simplified] split: if_split_asm) lemma OPT_Suc: "OPT (Suc i) v = min (OPT i v) (Min {OPT i w + W v w | w. w \ n})" (is "?lhs = ?rhs") if "t \ n" proof - have "OPT i w + W v w \ OPT (Suc i) v" if "w \ n" for w using OPT_cases[of i w] proof cases case (path xs) with \w \ n\ show ?thesis by (subst OPT_def) (auto intro!: Min_le exI[where x = "w # xs"] simp: add.commute) next case sink then show ?thesis by (subst OPT_def) (auto intro!: Min_le exI[where x = "[]"]) next case unreachable then show ?thesis by simp qed then have "Min {OPT i w + W v w |w. w \ n} \ OPT (Suc i) v" by (auto intro!: Min.boundedI) moreover have "OPT i v \ OPT (Suc i) v" unfolding OPT_def by (rule Min_antimono) auto ultimately have "?lhs \ ?rhs" by simp from OPT_cases[of "Suc i" v] have "?lhs \ ?rhs" proof cases case (path xs) note [simp] = path(1) from path consider (zero) "i = 0" "length xs = 0" | (new) "i > 0" "length xs = i" | (old) "length xs < i" by (cases "length xs = i") auto then show ?thesis proof cases case zero with path have "OPT (Suc i) v = W v t" by simp also have "W v t = OPT i t + W v t" unfolding OPT_def using \i = 0\ by auto also have "\ \ Min {OPT i w + W v w |w. w \ n}" using \t \ n\ by (auto intro: Min_le) finally show ?thesis by (rule min.coboundedI2) next case new with \_ = i\ obtain u ys where [simp]: "xs = u # ys" by (cases xs) auto from path have "OPT i u \ weight (u # ys @ [t])" unfolding OPT_def by (intro Min_le) auto from path have "Min {OPT i w + W v w |w. w \ n} \ W v u + OPT i u" by (intro Min_le) (auto simp: add.commute) also from \OPT i u \ _\ have "\ \ OPT (Suc i) v" by (simp add: add_left_mono) finally show ?thesis by (rule min.coboundedI2) next case old with path have "OPT i v \ OPT (Suc i) v" by (auto 4 3 intro: Min_le simp: OPT_def) then show ?thesis by (rule min.coboundedI1) qed next case unreachable then show ?thesis by simp next case sink then have "OPT i v \ OPT (Suc i) v" unfolding OPT_def by simp then show ?thesis by (rule min.coboundedI1) qed with \?lhs \ ?rhs\ show ?thesis by (rule order.antisym) qed fun bf :: "nat \ nat \ int extended" where "bf 0 v = (if t = v then 0 else \)" | "bf (Suc i) v = min_list (bf i v # [W v w + bf i w . w \ [0 ..< Suc n]])" lemmas [simp del] = bf.simps lemmas bf_simps[simp] = bf.simps[unfolded min_list_fold] lemma bf_correct: "OPT i j = bf i j" if \t \ n\ proof (induction i arbitrary: j) case 0 then show ?case by (simp add: OPT_0) next case (Suc i) have *: "{bf i w + W j w |w. w \ n} = set (map (\w. W j w + bf i w) [0..t \ n\ show ?case by (simp add: OPT_Suc del: upt_Suc, subst Min.set_eq_fold[symmetric], auto simp: *) qed subsubsection \Functional Memoization\ memoize_fun bf\<^sub>m: bf with_memory dp_consistency_mapping monadifies (state) bf.simps text \Generated Definitions\ context includes state_monad_syntax begin thm bf\<^sub>m'.simps bf\<^sub>m_def end text \Correspondence Proof\ memoize_correct by memoize_prover print_theorems lemmas [code] = bf\<^sub>m.memoized_correct interpretation iterator "\ (x, y). x \ n \ y \ n" "\ (x, y). if y < n then (x, y + 1) else (x + 1, 0)" "\ (x, y). x * (n + 1) + y" by (rule table_iterator_up) interpretation bottom_up: dp_consistency_iterator_empty "\ (_::(nat \ nat, int extended) mapping). True" "\ (x, y). bf x y" "\ k. do {m \ State_Monad.get; State_Monad.return (Mapping.lookup m k :: int extended option)}" "\ k v. do {m \ State_Monad.get; State_Monad.set (Mapping.update k v m)}" "\ (x, y). x \ n \ y \ n" "\ (x, y). if y < n then (x, y + 1) else (x + 1, 0)" "\ (x, y). x * (n + 1) + y" Mapping.empty .. definition "iter_bf = iter_state (\ (x, y). bf\<^sub>m' x y)" lemma iter_bf_unfold[code]: "iter_bf = (\ (i, j). (if i \ n \ j \ n then do { bf\<^sub>m' i j; iter_bf (if j < n then (i, j + 1) else (i + 1, 0)) } else State_Monad.return ()))" unfolding iter_bf_def by (rule ext) (safe, clarsimp simp: iter_state_unfold) lemmas bf_memoized = bf\<^sub>m.memoized[OF bf\<^sub>m.crel] lemmas bf_bottom_up = bottom_up.memoized[OF bf\<^sub>m.crel, folded iter_bf_def] text \ This will be our final implementation, which includes detection of negative cycles. See the corresponding section below for the correctness proof. \ definition "bellman_ford \ do { _ \ iter_bf (n, n); xs \ State_Main.map\<^sub>T' (\i. bf\<^sub>m' n i) [0.. State_Main.map\<^sub>T' (\i. bf\<^sub>m' (n + 1) i) [0.. do { _ \ iter_bf (n, n); (\\xs. \\ys. State_Monad.return (if xs = ys then Some xs else None)\ . (State_Main.map\<^sub>T . \\i. bf\<^sub>m' (n + 1) i\ . \[0..)\) . (State_Main.map\<^sub>T . \\i. bf\<^sub>m' n i\ . \[0..) }" unfolding State_Monad_Ext.fun_app_lifted_def bellman_ford_def State_Main.map\<^sub>T_def bind_left_identity . end subsubsection \Imperative Memoization\ context fixes mem :: "nat ref \ nat ref \ int extended option array ref \ int extended option array ref" assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty" begin lemma [intro]: "dp_consistency_heap_array_pair' (n + 1) fst snd id 1 0 mem" by (standard; simp add: mem_is_init injective_def) interpretation iterator "\ (x, y). x \ n \ y \ n" "\ (x, y). if y < n then (x, y + 1) else (x + 1, 0)" "\ (x, y). x * (n + 1) + y" by (rule table_iterator_up) lemma [intro]: "dp_consistency_heap_array_pair_iterator (n + 1) fst snd id 1 0 mem (\ (x, y). if y < n then (x, y + 1) else (x + 1, 0)) (\ (x, y). x * (n + 1) + y) (\ (x, y). x \ n \ y \ n)" by (standard; simp add: mem_is_init injective_def) memoize_fun bf\<^sub>h: bf with_memory (default_proof) dp_consistency_heap_array_pair_iterator where size = "n + 1" and key1 = "fst :: nat \ nat \ nat" and key2 = "snd :: nat \ nat \ nat" and k1 = "1 :: nat" and k2 = "0 :: nat" and to_index = "id :: nat \ nat" and mem = mem and cnt = "\ (x, y). x \ n \ y \ n" and nxt = "\ (x :: nat, y). if y < n then (x, y + 1) else (x + 1, 0)" and sizef = "\ (x, y). x * (n + 1) + y" monadifies (heap) bf.simps memoize_correct by memoize_prover lemmas memoized_empty = bf\<^sub>h.memoized_empty[OF bf\<^sub>h.consistent_DP_iter_and_compute[OF bf\<^sub>h.crel]] lemmas iter_heap_unfold = iter_heap_unfold end (* Fixed Memory *) subsubsection \Detecting Negative Cycles\ definition "shortest v = ( Inf ( {weight (v # xs @ [t]) | xs. set xs \ {0..n}} \ {if t = v then 0 else \} ) )" definition "is_path xs \ weight (xs @ [t]) < \" definition "has_negative_cycle \ \xs a ys. set (a # xs @ ys) \ {0..n} \ weight (a # xs @ [a]) < 0 \ is_path (a # ys)" definition "reaches a \ \xs. is_path (a # xs) \ a \ n \ set xs \ {0..n}" lemma fold_sum_aux': assumes "\u \ set (a # xs). \v \ set (xs @ [b]). f v + W u v \ f u" shows "sum_list (map f (a # xs)) \ sum_list (map f (xs @ [b])) + weight (a # xs @ [b])" using assms by (induction xs arbitrary: a; simp) (smt ab_semigroup_add_class.add_ac(1) add.left_commute add_mono) lemma fold_sum_aux: assumes "\u \ set (a # xs). \v \ set (a # xs). f v + W u v \ f u" shows "sum_list (map f (a # xs @ [a])) \ sum_list (map f (a # xs @ [a])) + weight (a # xs @ [a])" using fold_sum_aux'[of a xs a f] assms by auto (metis (no_types, opaque_lifting) add.assoc add.commute add_left_mono) context begin private definition "is_path2 xs \ weight xs < \" private lemma is_path2_remove_cycle: assumes "is_path2 (as @ a # bs @ a # cs)" shows "is_path2 (as @ a # cs)" proof - have "weight (as @ a # bs @ a # cs) = weight (as @ [a]) + weight (a # bs @ [a]) + weight (a # cs)" by (metis Bellman_Ford.weight_append append_Cons append_assoc) with assms have "weight (as @ [a]) < \" "weight (a # cs) < \" unfolding is_path2_def by (simp, metis Pinf_add_right antisym less_extended_simps(4) not_less add.commute)+ then show ?thesis unfolding is_path2_def by (subst weight_append) (rule add_lt_infI) qed private lemma is_path_eq: "is_path xs \ is_path2 (xs @ [t])" unfolding is_path_def is_path2_def .. lemma is_path_remove_cycle: assumes "is_path (as @ a # bs @ a # cs)" shows "is_path (as @ a # cs)" using assms unfolding is_path_eq by (simp add: is_path2_remove_cycle) lemma is_path_remove_cycle2: assumes "is_path (as @ t # cs)" shows "is_path as" using assms unfolding is_path_eq by (simp add: is_path2_remove_cycle) end (* private lemmas *) lemma is_path_shorten: assumes "is_path (i # xs)" "i \ n" "set xs \ {0..n}" "t \ n" "t \ i" obtains xs where "is_path (i # xs)" "i \ n" "set xs \ {0..n}" "length xs < n" proof (cases "length xs < n") case True with assms show ?thesis by (auto intro: that) next case False then have "length xs \ n" by auto with assms(1,3) show ?thesis proof (induction "length xs" arbitrary: xs rule: less_induct) case less then have "length (i # xs @ [t]) > card ({0..n})" by auto moreover from less.prems \i \ n\ \t \ n\ have "set (i # xs @ [t]) \ {0..n}" by auto ultimately obtain a as bs cs where *: "i # xs @ [t] = as @ a # bs @ a # cs" by (elim list_pidgeonhole) auto obtain ys where ys: "is_path (i # ys)" "length ys < length xs" "set (i # ys) \ {0..n}" apply atomize_elim using * proof (cases rule: path_eq_cycleE) case Nil_Nil with \t \ i\ show "\ys. is_path (i # ys) \ length ys < length xs \ set (i # ys) \ {0..n}" by auto next case (Nil_Cons cs') then show "\ys. is_path (i # ys) \ length ys < length xs \ set (i # ys) \ {0..n}" using \set (i # xs @ [t]) \ {0..n}\ \is_path (i # xs)\ is_path_remove_cycle[of "[]"] by - (rule exI[where x = cs'], simp) next case (Cons_Nil as') then show "\ys. is_path (i # ys) \ length ys < length xs \ set (i # ys) \ {0..n}" using \set (i # xs @ [t]) \ {0..n}\ \is_path (i # xs)\ by - (rule exI[where x = as'], auto intro: is_path_remove_cycle2) next case (Cons_Cons as' cs') then show "\ys. is_path (i # ys) \ length ys < length xs \ set (i # ys) \ {0..n}" using \set (i # xs @ [t]) \ {0..n}\ \is_path (i # xs)\ is_path_remove_cycle[of "i # as'"] by - (rule exI[where x = "as' @ a # cs'"], auto) qed then show ?thesis by (cases "n \ length ys") (auto intro: that less) qed qed lemma reaches_non_inf_path: assumes "reaches i" "i \ n" "t \ n" shows "OPT n i < \" proof (cases "t = i") case True with \i \ n\ \t \ n\ have "OPT n i \ 0" unfolding OPT_def by (auto intro: Min_le simp: finite_lists_length_le2[simplified]) then show ?thesis using less_linear by (fastforce simp: zero_extended_def) next case False from assms(1) obtain xs where "is_path (i # xs)" "i \ n" "set xs \ {0..n}" unfolding reaches_def by safe then obtain xs where xs: "is_path (i # xs)" "i \ n" "set xs \ {0..n}" "length xs < n" using \t \ i\ \t \ n\ by (auto intro: is_path_shorten) then have "weight (i # xs @ [t]) < \" unfolding is_path_def by auto with xs(2-) show ?thesis unfolding OPT_def by (elim order.strict_trans1[rotated]) (auto simp: setcompr_eq_image finite_lists_length_le2[simplified]) qed lemma OPT_sink_le_0: "OPT i t \ 0" unfolding OPT_def by (auto simp: finite_lists_length_le2[simplified]) lemma is_path_appendD: assumes "is_path (as @ a # bs)" shows "is_path (a # bs)" using assms weight_append[of as a "bs @ [t]"] unfolding is_path_def by simp (metis Pinf_add_right add.commute less_extended_simps(4) not_less_iff_gr_or_eq) lemma has_negative_cycleI: assumes "set (a # xs @ ys) \ {0..n}" "weight (a # xs @ [a]) < 0" "is_path (a # ys)" shows has_negative_cycle using assms unfolding has_negative_cycle_def by auto lemma OPT_cases2: obtains (path) xs where "v \ t" "OPT i v \ \" "OPT i v = weight (v # xs @ [t])" "length xs + 1 \ i" "set xs \ {0..n}" | (unreachable) "v \ t" "OPT i v = \" | (sink) "v = t" "OPT i v \ 0" unfolding OPT_def using Min_in[of "{weight (v # xs @ [t]) |xs. length xs + 1 \ i \ set xs \ {0..n}} \ {if t = v then 0 else \}"] by (cases "v = t"; force simp: finite_lists_length_le2[simplified] split: if_split_asm) lemma shortest_le_OPT: assumes "v \ n" shows "shortest v \ OPT i v" unfolding OPT_def shortest_def apply (subst Min_Inf) apply (simp add: setcompr_eq_image finite_lists_length_le2[simplified]; fail)+ apply (rule Inf_superset_mono) apply auto done context assumes W_wellformed: "\i \ n. \j \ n. W i j > -\" assumes "t \ n" begin lemma weight_not_minfI: "-\ < weight xs" if "set xs \ {0..n}" "xs \ []" using that using W_wellformed \t \ n\ by (induction xs rule: induct_list012) (auto intro: add_gt_minfI simp: zero_extended_def) lemma OPT_not_minfI: "OPT n i > -\" if "i \ n" proof - have "OPT n i \ {weight (i # xs @ [t]) |xs. length xs + 1 \ n \ set xs \ {0..n}} \ {if t = i then 0 else \}" unfolding OPT_def by (rule Min_in) (auto simp: setcompr_eq_image finite_lists_length_le2[simplified]) with that \t \ n\ show ?thesis by (auto 4 3 intro!: weight_not_minfI simp: zero_extended_def) qed theorem detects_cycle: assumes has_negative_cycle shows "\i \ n. OPT (n + 1) i < OPT n i" proof - from assms \t \ n\ obtain xs a ys where cycle: "a \ n" "set xs \ {0..n}" "set ys \ {0..n}" "weight (a # xs @ [a]) < 0" "is_path (a # ys)" unfolding has_negative_cycle_def by clarsimp then have "reaches a" unfolding reaches_def by auto have reaches: "reaches x" if "x \ set xs" for x proof - from that obtain as bs where "xs = as @ x # bs" by atomize_elim (rule split_list) with cycle have "weight (x # bs @ [a]) < \" using weight_append[of "a # as" x "bs @ [a]"] by simp (metis Pinf_add_right Pinf_le add.commute less_eq_extended.simps(2) not_less) moreover from \reaches a\ obtain cs where "local.weight (a # cs @ [t]) < \" "set cs \ {0..n}" unfolding reaches_def is_path_def by auto ultimately show ?thesis unfolding reaches_def is_path_def using \a \ n\ weight_append[of "x # bs" a "cs @ [t]"] cycle(2) \xs = _\ by - (rule exI[where x = "bs @ [a] @ cs"], auto intro: add_lt_infI) qed let ?S = "sum_list (map (OPT n) (a # xs @ [a]))" obtain u v where "u \ n" "v \ n" "OPT n v + W u v < OPT n u" proof (atomize_elim, rule ccontr) assume "\u v. u \ n \ v \ n \ OPT n v + W u v < OPT n u" then have "?S \ ?S + weight (a # xs @ [a])" using cycle(1-3) by (subst fold_sum_aux; fastforce simp: subset_eq) moreover have "?S > -\" using cycle(1-4) by (intro sum_list_not_minfI, auto intro!: OPT_not_minfI) moreover have "?S < \" using reaches \t \ n\ cycle(1,2) by (intro sum_list_not_infI) (auto intro: reaches_non_inf_path \reaches a\ simp: subset_eq) ultimately have "weight (a # xs @ [a]) \ 0" by (simp add: le_add_same_cancel1) with \weight _ < 0\ show False by simp qed then show ?thesis by - (rule exI[where x = u], auto 4 4 intro: Min.coboundedI min.strict_coboundedI2 elim: order.strict_trans1[rotated] simp: OPT_Suc[OF \t \ n\]) qed corollary bf_detects_cycle: assumes has_negative_cycle shows "\i \ n. bf (n + 1) i < bf n i" using detects_cycle[OF assms] unfolding bf_correct[OF \t \ n\] . lemma shortest_cases: assumes "v \ n" obtains (path) xs where "shortest v = weight (v # xs @ [t])" "set xs \ {0..n}" | (sink) "v = t" "shortest v = 0" | (unreachable) "v \ t" "shortest v = \" | (negative_cycle) "shortest v = -\" "\x. \xs. set xs \ {0..n} \ weight (v # xs @ [t]) < Fin x" proof - let ?S = "{weight (v # xs @ [t]) | xs. set xs \ {0..n}} \ {if t = v then 0 else \}" have "?S \ {}" by auto have Minf_lowest: False if "-\ < a" "-\ = a" for a :: "int extended" using that by auto show ?thesis proof (cases "shortest v") case (Fin x) then have "-\ \ ?S" "bdd_below (Fin -` ?S)" "?S \ {\}" "x = Inf (Fin -` ?S)" unfolding shortest_def Inf_extended_def by (auto split: if_split_asm) from this(1-3) have "x \ Fin -` ?S" unfolding \x = _\ by (intro Inf_int_in, auto simp: zero_extended_def) (smt empty_iff extended.exhaust insertI2 mem_Collect_eq vimage_eq) with \shortest v = _\ show ?thesis unfolding vimage_eq by (auto split: if_split_asm intro: that) next case Pinf with \?S \ {}\ have "t \ v" unfolding shortest_def Inf_extended_def by (auto split: if_split_asm) with \_ = \\ show ?thesis by (auto intro: that) next case Minf then have "?S \ {}" "?S \ {\}" "-\ \ ?S \ \ bdd_below (Fin -` ?S)" unfolding shortest_def Inf_extended_def by (auto split: if_split_asm) from this(3) have "\x. \xs. set xs \ {0..n} \ weight (v # xs @ [t]) < Fin x" proof assume "-\ \ ?S" with weight_not_minfI have False using \v \ n\ \t \ n\ by (auto split: if_split_asm elim: Minf_lowest[rotated]) then show ?thesis .. next assume "\ bdd_below (Fin -` ?S)" show ?thesis proof fix x :: int let ?m = "min x (-1)" from \\ bdd_below _\ obtain m where "Fin m \ ?S" "m < ?m" unfolding bdd_below_def by - (simp, drule spec[of _ "?m"], force) then show "\xs. set xs \ {0..n} \ weight (v # xs @ [t]) < Fin x" by (auto split: if_split_asm simp: zero_extended_def) (metis less_extended_simps(1))+ qed qed with \shortest v = _\ show ?thesis by (auto intro: that) qed qed lemma simple_paths: assumes "\ has_negative_cycle" "weight (v # xs @ [t]) < \" "set xs \ {0..n}" "v \ n" obtains ys where "weight (v # ys @ [t]) \ weight (v # xs @ [t])" "set ys \ {0..n}" "length ys < n" | "v = t" using assms(2-) proof (atomize_elim, induction "length xs" arbitrary: xs rule: less_induct) case (less ys) note ys = less.prems(1,2) note IH = less.hyps have path: "is_path (v # ys)" using is_path_def not_less_iff_gr_or_eq ys(1) by fastforce show ?case proof (cases "length ys \ n") case True with ys \v \ n\ \t \ n\ obtain a as bs cs where "v # ys @ [t] = as @ a # bs @ a # cs" by - (rule list_pidgeonhole[of "v # ys @ [t]" "{0..n}"], auto) then show ?thesis proof (cases rule: path_eq_cycleE) case Nil_Nil then show ?thesis by simp next case (Nil_Cons cs') then have *: "weight (v # ys @ [t]) = weight (a # bs @ [a]) + weight (a # cs' @ [t])" by (simp add: weight_append[of "a # bs" a "cs' @ [t]", simplified]) show ?thesis proof (cases "weight (a # bs @ [a]) < 0") case True with Nil_Cons \set ys \ _\ path show ?thesis using assms(1) by (force intro: has_negative_cycleI[of a bs ys]) next case False then have "weight (a # bs @ [a]) \ 0" by auto with * ys have "weight (a # cs' @ [t]) \ weight (v # ys @ [t])" using add_mono not_le by fastforce with Nil_Cons \length ys \ n\ ys show ?thesis using IH[of cs'] by simp (meson le_less_trans order_trans) qed next case (Cons_Nil as') with ys have *: "weight (v # ys @ [t]) = weight (v # as' @ [t]) + weight (a # bs @ [a])" using weight_append[of "v # as'" t "bs @ [t]"] by simp show ?thesis proof (cases "weight (a # bs @ [a]) < 0") case True with Cons_Nil \set ys \ _\ path assms(1) show ?thesis using is_path_appendD[of "v # as'"] by (force intro: has_negative_cycleI[of a bs bs]) next case False then have "weight (a # bs @ [a]) \ 0" by auto with * ys(1) have "weight (v # as' @ [t]) \ weight (v # ys @ [t])" using add_left_mono by fastforce with Cons_Nil \length ys \ n\ \v \ n\ ys show ?thesis using IH[of as'] by simp (meson le_less_trans order_trans) qed next case (Cons_Cons as' cs') with ys have *: "weight (v # ys @ [t]) = weight (v # as' @ a # cs' @ [t]) + weight (a # bs @ [a])" using weight_append[of "v # as'" a "bs @ a # cs' @ [t]"] weight_append[of "a # bs" a "cs' @ [t]"] weight_append[of "v # as'" a "cs' @ [t]"] by (simp add: algebra_simps) show ?thesis proof (cases "weight (a # bs @ [a]) < 0") case True with Cons_Cons \set ys \ _\ path assms(1) show ?thesis using is_path_appendD[of "v # as'"] by (force intro: has_negative_cycleI[of a bs "bs @ a # cs'"]) next case False then have "weight (a # bs @ [a]) \ 0" by auto with * ys have "weight (v # as' @ a # cs' @ [t]) \ weight (v # ys @ [t])" using add_left_mono by fastforce with Cons_Cons \v \ n\ ys show ?thesis using is_path_remove_cycle2 IH[of "as' @ a # cs'"] by simp (meson le_less_trans order_trans) qed qed next case False with \set ys \ _\ show ?thesis by auto qed qed theorem shorter_than_OPT_n_has_negative_cycle: assumes "shortest v < OPT n v" "v \ n" shows has_negative_cycle proof - from assms obtain ys where ys: "weight (v # ys @ [t]) < OPT n v" "set ys \ {0..n}" apply (cases rule: OPT_cases2[of v n]; cases rule: shortest_cases[OF \v \ n\]; simp) apply (metis uminus_extended.cases) using less_extended_simps(2) less_trans apply blast apply (metis less_eq_extended.elims(2) less_extended_def zero_extended_def) done show ?thesis proof (cases "v = t") case True with ys \t \ n\ show ?thesis using OPT_sink_le_0[of n] unfolding has_negative_cycle_def is_path_def using less_extended_def by force next case False show ?thesis proof (rule ccontr) assume "\ has_negative_cycle" with False False ys \v \ n\ obtain xs where "weight (v # xs @ [t]) \ weight (v # ys @ [t])" "set xs \ {0..n}" "length xs < n" using less_extended_def by (fastforce elim!: simple_paths[of v ys]) then have "OPT n v \ weight (v # xs @ [t])" unfolding OPT_def by (intro Min_le) auto with \_ \ weight (v # ys @ [t])\ \weight (v # ys @ [t]) < OPT n v\ show False by simp qed qed qed corollary detects_cycle_has_negative_cycle: assumes "OPT (n + 1) v < OPT n v" "v \ n" shows has_negative_cycle using assms shortest_le_OPT[of v "n + 1"] shorter_than_OPT_n_has_negative_cycle[of v] by auto corollary bellman_ford_detects_cycle: "has_negative_cycle \ (\v \ n. OPT (n + 1) v < OPT n v)" using detects_cycle_has_negative_cycle detects_cycle by blast corollary bellman_ford_shortest_paths: assumes "\ has_negative_cycle" shows "\v \ n. bf n v = shortest v" proof - have "OPT n v \ shortest v" if "v \ n" for v using that assms shorter_than_OPT_n_has_negative_cycle[of v] by force then show ?thesis unfolding bf_correct[OF \t \ n\, symmetric] by (safe, rule order.antisym) (auto elim: shortest_le_OPT) qed lemma OPT_mono: "OPT m v \ OPT n v" if \v \ n\ \n \ m\ using that unfolding OPT_def by (intro Min_antimono) auto corollary bf_fix: assumes "\ has_negative_cycle" "m \ n" shows "\v \ n. bf m v = bf n v" proof (intro allI impI) fix v assume "v \ n" from \v \ n\ \n \ m\ have "shortest v \ OPT m v" by (simp add: shortest_le_OPT) moreover from \v \ n\ \n \ m\ have "OPT m v \ OPT n v" by (rule OPT_mono) moreover from \v \ n\ assms have "OPT n v \ shortest v" using shorter_than_OPT_n_has_negative_cycle[of v] by force ultimately show "bf m v = bf n v" unfolding bf_correct[OF \t \ n\, symmetric] by simp qed lemma bellman_ford_correct': "bf\<^sub>m.crel_vs (=) (if has_negative_cycle then None else Some (map shortest [0..m' = bf\<^sub>m.crel[unfolded bf\<^sub>m.consistentDP_def, THEN rel_funD, of "(m, x)" "(m, y)" for m x y, unfolded prod.case] have "?l = ?r" supply [simp del] = bf_simps supply [simp add] = bf_fix[rule_format, symmetric] bellman_ford_shortest_paths[rule_format, symmetric] unfolding Wrap_def App_def using bf_detects_cycle by (fastforce elim: nat_le_cases) \ \Slightly transform the goal, then apply parametric reasoning like usual.\ show ?thesis \ \Roughly \ unfolding bellman_ford_alt_def \?l = ?r\ \ \Obtain parametric form.\ apply (rule bf\<^sub>m.crel_vs_bind_ignore[rotated]) \ \Drop bind.\ apply (rule bottom_up.consistent_crel_vs_iterate_state[OF bf\<^sub>m.crel, folded iter_bf_def]) apply (subst Transfer.Rel_def[symmetric]) \ \Setup typical goal for automated reasoner.\ \ \We need to reason manually because we are not in the context where \bf\<^sub>m\ was defined.\ \ \This is roughly what @{method "memoize_prover_match_step"}/\Transform_Tactic.step_tac\ does.\ - ML_prf \val ctxt = @{context}\ - apply (tactic \Transform_Tactic.solve_relator_tac ctxt 1\ + apply (tactic \Transform_Tactic.solve_relator_tac \<^context> 1\ | rule HOL.refl | rule bf\<^sub>m.dp_match_rule | rule bf\<^sub>m.crel_vs_return_ext | (subst Rel_def, rule crel_bf\<^sub>m') - | tactic \Transform_Tactic.transfer_raw_tac ctxt 1\)+ + | tactic \Transform_Tactic.transfer_raw_tac \<^context> 1\)+ done qed theorem bellman_ford_correct: "fst (run_state bellman_ford Mapping.empty) = (if has_negative_cycle then None else Some (map shortest [0..m.cmem_empty bellman_ford_correct'[unfolded bf\<^sub>m.crel_vs_def, rule_format, of Mapping.empty] unfolding bf\<^sub>m.crel_vs_def by auto end (* Wellformedness *) end (* Final Node *) end (* Bellman Ford *) subsubsection \Extracting an Executable Constant for the Imperative Implementation\ ground_function (prove_termination) bf\<^sub>h'_impl: bf\<^sub>h'.simps lemma bf\<^sub>h'_impl_def: fixes n :: nat fixes mem :: "nat ref \ nat ref \ int extended option array ref \ int extended option array ref" assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty" shows "bf\<^sub>h'_impl n w t mem = bf\<^sub>h' n w t mem" proof - have "bf\<^sub>h'_impl n w t mem i j = bf\<^sub>h' n w t mem i j" for i j by (induction rule: bf\<^sub>h'.induct[OF mem_is_init]; simp add: bf\<^sub>h'.simps[OF mem_is_init]; solve_cong simp ) then show ?thesis by auto qed definition "iter_bf_heap n w t mem = iterator_defs.iter_heap (\(x, y). x \ n \ y \ n) (\(x, y). if y < n then (x, y + 1) else (x + 1, 0)) (\(x, y). bf\<^sub>h'_impl n w t mem x y)" lemma iter_bf_heap_unfold[code]: "iter_bf_heap n w t mem = (\ (i, j). (if i \ n \ j \ n then do { bf\<^sub>h'_impl n w t mem i j; iter_bf_heap n w t mem (if j < n then (i, j + 1) else (i + 1, 0)) } else Heap_Monad.return ()))" unfolding iter_bf_heap_def by (rule ext) (safe, simp add: iter_heap_unfold) definition "bf_impl n w t i j = do { mem \ (init_state (n + 1) (1::nat) (0::nat) :: (nat ref \ nat ref \ int extended option array ref \ int extended option array ref) Heap); iter_bf_heap n w t mem (0, 0); bf\<^sub>h'_impl n w t mem i j }" lemma bf_impl_correct: "bf n w t i j = result_of (bf_impl n w t i j) Heap.empty" using memoized_empty[OF HOL.refl, of n w t "(i, j)"] by (simp add: execute_bind_success[OF succes_init_state] bf_impl_def bf\<^sub>h'_impl_def iter_bf_heap_def ) subsubsection \Test Cases\ definition "G\<^sub>1_list = [[(1 :: nat,-6 :: int), (2,4), (3,5)], [(3,10)], [(3,2)], []]" definition "G\<^sub>2_list = [[(1 :: nat,-6 :: int), (2,4), (3,5)], [(3,10)], [(3,2)], [(0, -5)]]" definition "G\<^sub>3_list = [[(1 :: nat,-1 :: int), (2,2)], [(2,5), (3,4)], [(3,2), (4,3)], [(2,-2), (4,2)], []]" definition "G\<^sub>4_list = [[(1 :: nat,-1 :: int), (2,2)], [(2,5), (3,4)], [(3,2), (4,3)], [(2,-3), (4,2)], []]" definition "graph_of a i j = case_option \ (Fin o snd) (List.find (\ p. fst p = j) (a !! i))" definition "test_bf = bf_impl 3 (graph_of (IArray G\<^sub>1_list)) 3 3 0" code_reflect Test functions test_bf text \One can see a trace of the calls to the memory in the output\ ML \Test.test_bf ()\ lemma bottom_up_alt[code]: "bf n W t i j = fst (run_state (iter_bf n W t (0, 0) \ (\_. bf\<^sub>m' n W t i j)) Mapping.empty)" using bf_bottom_up by auto definition "bf_ia n W t i j = (let W' = graph_of (IArray W) in fst (run_state (iter_bf n W' t (i, j) \ (\_. bf\<^sub>m' n W' t i j)) Mapping.empty) )" \ \Component tests.\ lemma "fst (run_state (bf\<^sub>m' 3 (graph_of (IArray G\<^sub>1_list)) 3 3 0) Mapping.empty) = 4" "bf 3 (graph_of (IArray G\<^sub>1_list)) 3 3 0 = 4" by eval+ \ \Regular test cases.\ lemma "fst (run_state (bellman_ford 3 (graph_of (IArray G\<^sub>1_list)) 3) Mapping.empty) = Some [4, 10, 2, 0]" "fst (run_state (bellman_ford 4 (graph_of (IArray G\<^sub>3_list)) 4) Mapping.empty) = Some [4, 5, 3, 1, 0]" by eval+ \ \Test detection of negative cycles.\ lemma "fst (run_state (bellman_ford 3 (graph_of (IArray G\<^sub>2_list)) 3) Mapping.empty) = None" "fst (run_state (bellman_ford 4 (graph_of (IArray G\<^sub>4_list)) 4) Mapping.empty) = None" by eval+ end (* Theory *) \ No newline at end of file diff --git a/thys/Monad_Memo_DP/transform/Transform.ML b/thys/Monad_Memo_DP/transform/Transform.ML --- a/thys/Monad_Memo_DP/transform/Transform.ML +++ b/thys/Monad_Memo_DP/transform/Transform.ML @@ -1,238 +1,237 @@ -structure Transform_DP = struct +signature TRANSFORM_DP = +sig + val dp_fun_part1_cmd: + (binding * string) * ((bool * (xstring * Position.T)) * (string * string) list) option + -> local_theory -> local_theory + val dp_fun_part2_cmd: string * (Facts.ref * Token.src list) list -> local_theory -> local_theory + val dp_correct_cmd: local_theory -> Proof.state +end + +structure Transform_DP : TRANSFORM_DP = +struct fun dp_interpretation standard_proof locale_name instance qualifier dp_term lthy = lthy |> Interpretation.isar_interpretation ([(locale_name, ((qualifier, true), (Expression.Named (("dp", dp_term) :: instance), [])))], []) |> (if standard_proof then Proof.global_default_proof else Proof.global_immediate_proof) -fun prep_params (((scope, tm_str), def_thms_opt), mem_locale_opt) lthy = +fun prep_params (((scope, tm_str), def_thms_opt), mem_locale_opt) ctxt = let - val tm = Syntax.read_term lthy tm_str + val tm = Syntax.read_term ctxt tm_str val scope' = (Binding.is_empty scope? Binding.map_name (fn _ => Transform_Misc.term_name tm ^ "\<^sub>T")) scope - val def_thms_opt' = Option.map (Attrib.eval_thms lthy) def_thms_opt - val mem_locale_opt' = Option.map (Locale.check (Proof_Context.theory_of lthy)) mem_locale_opt + val def_thms_opt' = Option.map (Attrib.eval_thms ctxt) def_thms_opt + val mem_locale_opt' = Option.map (Locale.check (Proof_Context.theory_of ctxt)) mem_locale_opt in (scope', tm, def_thms_opt', mem_locale_opt') end (* fun dp_interpretation_cmd args lthy = let val (scope, tm, _, mem_locale_opt) = prep_params args lthy val scope_name = Binding.name_of scope in case mem_locale_opt of NONE => lthy | SOME x => dp_interpretation x scope_name (Transform_Misc.uncurry tm) lthy end *) fun do_monadify heap_name scope tm mem_locale_opt def_thms_opt lthy = let val monad_consts = Transform_Const.get_monad_const heap_name val scope_name = Binding.name_of scope val memoizer_opt = if is_none mem_locale_opt then NONE else SOME (Transform_Misc.locale_term lthy scope_name "checkmem") val old_info_opt = Function_Common.import_function_data tm lthy val old_defs_opt = [ K def_thms_opt, K (Option.mapPartial #simps old_info_opt) ] |> Library.get_first (fn x => x ()) val old_defs = case old_defs_opt of SOME defs => defs | NONE => raise TERM("no definition", [tm]) val ((_, old_defs_imported), _) = Variable.import true old_defs lthy (* val new_bind = Binding.suffix_name "\<^sub>T'" scope val new_bindT = Binding.suffix_name "\<^sub>T" scope *) val new_bind = Binding.suffix_name "'" scope val new_bindT = scope fun dest_def (def, def_imported) = let val def_imported_meta = def_imported |> Local_Defs.meta_rewrite_rule lthy val eqs = def_imported_meta |> Thm.full_prop_of val (head, _) = Logic.dest_equals eqs |> fst |> Transform_Misc.behead tm (*val _ = if Term.aconv_untyped (head, tm) then () else raise THM("invalid definition", 0, [def])*) val Abs t = Term.lambda_name (Binding.name_of new_bind, head) eqs val (t_name, eqs') = Term.dest_abs t val _ = @{assert} (t_name = Binding.name_of new_bind) (*val eqs' = Term.subst_atomic [(head, Free (Binding.name_of new_bind, fastype_of head))] eqs*) val (rhs_conv, eqsT, n_args) = Transform_Term.lift_equation monad_consts lthy (Logic.dest_equals eqs') memoizer_opt val def_meta' = def |> Local_Defs.meta_rewrite_rule lthy |> Conv.fconv_rule (Conv.arg_conv (rhs_conv lthy)) val def_meta_simped = def_meta' |> Conv.fconv_rule ( - repeat_sweep_conv (K Transform_Term.rewrite_pureapp_beta_conv) lthy + Transform_Term.repeat_sweep_conv (K Transform_Term.rewrite_pureapp_beta_conv) lthy ) (* val eqsT_simped = eqsT |> Syntax.check_term lthy |> Thm.cterm_of lthy - |> repeat_sweep_conv (K Transform_Term.rewrite_app_beta_conv) lthy + |> Transform_Term.repeat_sweep_conv (K Transform_Term.rewrite_app_beta_conv) lthy |> Thm.full_prop_of |> Logic.dest_equals |> snd *) in ((def_meta_simped, eqsT), n_args) end val ((old_defs', new_defs_raw), n_args) = map dest_def (old_defs ~~ old_defs_imported) |> split_list |>> split_list ||> Transform_Misc.the_element val new_defs = Syntax.check_props lthy new_defs_raw |> map (fn eqsT => eqsT |> Thm.cterm_of lthy - |> repeat_sweep_conv (K (#rewrite_app_beta_conv monad_consts)) lthy + |> Transform_Term.repeat_sweep_conv (K (#rewrite_app_beta_conv monad_consts)) lthy |> Thm.full_prop_of |> Logic.dest_equals |> snd) (*val _ = map (Pretty.writeln o Syntax.pretty_term @{context} o Thm.full_prop_of) old_defs'*) (*val (new_defs, lthy) = Variable.importT_terms new_defs lthy*) - val (new_info, lthy) = Transform_Misc.add_function new_bind new_defs lthy + val (new_info, lthy1) = Transform_Misc.add_function new_bind new_defs lthy val replay_tac = case old_info_opt of NONE => no_tac - | SOME info => Transform_Tactic.totality_replay_tac info new_info lthy + | SOME info => Transform_Tactic.totality_replay_tac info new_info lthy1 val totality_tac = replay_tac - ORELSE (Function_Common.termination_prover_tac false lthy + ORELSE (Function_Common.termination_prover_tac false lthy1 THEN Transform_Tactic.my_print_tac "termination by default prover") - val (new_info, lthy) = Function.prove_termination NONE totality_tac lthy + val (new_info, lthy2) = Function.prove_termination NONE totality_tac lthy1 val new_def' = new_info |> #simps |> the val head' = new_info |> #fs |> the_single - val headT = Transform_Term.wrap_head monad_consts head' n_args |> Syntax.check_term lthy - val ((headTC, (_, new_defT)), lthy) = Local_Theory.define ((new_bindT, NoSyn), ((Thm.def_binding new_bindT,[]), headT)) lthy + val headT = Transform_Term.wrap_head monad_consts head' n_args |> Syntax.check_term lthy2 + val ((headTC, (_, new_defT)), lthy) = Local_Theory.define ((new_bindT, NoSyn), ((Thm.def_binding new_bindT,[]), headT)) lthy2 - val lthy = Transform_Data.commit_dp_info (#monad_name monad_consts) ({ + val lthy3 = Transform_Data.commit_dp_info (#monad_name monad_consts) ({ old_head = tm, new_head' = head', new_headT = headTC, old_defs = old_defs', new_def' = new_def', new_defT = new_defT }) lthy - val _ = Proof_Display.print_consts true (Position.thread_data ()) lthy (K false) [ + val _ = Proof_Display.print_consts true (Position.thread_data ()) lthy3 (K false) [ (Binding.name_of new_bind, Term.type_of head'), (Binding.name_of new_bindT, Term.type_of headTC)] - in lthy end + in lthy3 end fun gen_dp_monadify prep_term args lthy = let val (scope, tm, def_thms_opt, mem_locale_opt) = prep_params args lthy (* val memoizer_opt = memoizer_scope_opt |> Option.map (fn memoizer_scope => Syntax.read_term lthy (Long_Name.qualify memoizer_scope Transform_Const.checkmemVN)) val _ = memoizer_opt |> Option.map (fn memoizer => if Term.aconv_untyped (head_of memoizer, @{term mem_defs.checkmem}) then () else raise TERM("invalid memoizer", [the memoizer_opt])) *) in do_monadify "state" scope tm mem_locale_opt def_thms_opt lthy end val dp_monadify_cmd = gen_dp_monadify Syntax.read_term fun dp_fun_part1_cmd ((scope, tm_str), (mem_locale_instance_opt)) lthy = let val scope_name = Binding.name_of scope val tm = Syntax.read_term lthy tm_str val _ = if is_Free tm then warning ("Free term: " ^ (Syntax.pretty_term lthy tm |> Pretty.string_of)) else () val mem_locale_opt' = Option.map (Locale.check (Proof_Context.theory_of lthy) o (snd o fst)) mem_locale_instance_opt - val lthy_f = case mem_locale_instance_opt of - NONE => I + val lthy1 = case mem_locale_instance_opt of + NONE => lthy | SOME ((standard_proof, locale_name), instance) => let val locale_name = Locale.check (Proof_Context.theory_of lthy) locale_name val instance = map (apsnd (Syntax.read_term lthy)) instance in - dp_interpretation standard_proof locale_name instance scope_name (Transform_Misc.uncurry tm) + dp_interpretation standard_proof locale_name instance scope_name (Transform_Misc.uncurry tm) lthy end - - val lthy = lthy_f lthy - val lthy = Transform_Data.add_tmp_cmd_info (Binding.reset_pos scope, tm, mem_locale_opt') lthy - in - lthy + Transform_Data.add_tmp_cmd_info (Binding.reset_pos scope, tm, mem_locale_opt') lthy1 end fun dp_fun_part2_cmd (heap_name, def_thms_str) lthy = let val {scope, head=tm, locale=locale_opt, dp_info=dp_info_opt} = Transform_Data.get_last_cmd_info lthy val _ = if is_none dp_info_opt then () else raise TERM("already monadified", [tm]) val def_thms = Attrib.eval_thms lthy def_thms_str - val heap_typ = Syntax.read_typ - val lthy = do_monadify heap_name scope tm locale_opt (SOME def_thms) lthy in - lthy + do_monadify heap_name scope tm locale_opt (SOME def_thms) lthy end fun dp_correct_cmd lthy = let val {scope, head=tm, locale=locale_opt, dp_info=dp_info_opt} = Transform_Data.get_last_cmd_info lthy val dp_info = case dp_info_opt of SOME x => x | NONE => raise TERM("not yet monadified", [tm]) val _ = if is_some locale_opt then () else raise TERM("not interpreted yet", [tm]) val scope_name = Binding.name_of scope val consistentDP = Transform_Misc.locale_term lthy scope_name "consistentDP" val dpT' = #new_head' dp_info val dpT'_curried = dpT' |> Transform_Misc.uncurry val goal_pat = consistentDP $ dpT'_curried val goal_prop = Syntax.check_term lthy (HOLogic.mk_Trueprop goal_pat) val tuple_pat = type_of dpT' |> strip_type |> fst |> length |> Name.invent_list [] "a" |> map (fn s => Var ((s, 0), TVar ((s, 0), @{sort type}))) |> HOLogic.mk_tuple |> Thm.cterm_of lthy val memoized_thm_opt = Transform_Misc.locale_thms lthy scope_name "memoized" |> the_single |> SOME handle ERROR msg => (warning msg; NONE) val memoized_thm'_opt = memoized_thm_opt |> Option.map (Drule.infer_instantiate' lthy [NONE, SOME tuple_pat]) fun display_thms thm_binds ctxt = - Proof_Display.print_results true (Position.thread_data ()) ctxt((Thm.theoremK, ""), [thm_binds]) + Proof_Display.print_results true (Position.thread_data ()) ctxt ((Thm.theoremK, ""), [thm_binds]) val crel_thm_name = "crel" val memoized_thm_name = "memoized_correct" - fun afterqed thmss ctxt = + fun afterqed result lthy1 = let - val [[crel_thm]] = thmss - - val (crel_thm_binds, ctxt) = Local_Theory.note ( - (Binding.qualify_name true scope crel_thm_name, []), - [crel_thm] - ) ctxt - - val _ = display_thms crel_thm_binds ctxt + val [[crel_thm]] = result - val ctxt = case memoized_thm'_opt of NONE => ctxt | SOME memoized_thm' => let - val (memoized_thm_binds, ctxt) = Local_Theory.note ( - (Binding.qualify_name true scope memoized_thm_name, []), - [(crel_thm RS memoized_thm') |> Local_Defs.unfold lthy @{thms prod.case}] - ) ctxt + val (crel_thm_binds, lthy2) = lthy1 + |> Local_Theory.note ((Binding.qualify_name true scope crel_thm_name, []), [crel_thm]) - val _ = display_thms memoized_thm_binds ctxt - in ctxt end + val _ = display_thms crel_thm_binds lthy2 in - ctxt + case memoized_thm'_opt of + NONE => lthy2 + | SOME memoized_thm' => + let + val (memoized_thm_binds, lthy3) = lthy2 + |> Local_Theory.note + ((Binding.qualify_name true scope memoized_thm_name, []), + [(crel_thm RS memoized_thm') |> Local_Defs.unfold lthy @{thms prod.case}]) + val _ = display_thms memoized_thm_binds lthy3 + in lthy3 end end - - val goal = Proof.theorem NONE afterqed [[(goal_prop, [])]] lthy in - goal + Proof.theorem NONE afterqed [[(goal_prop, [])]] lthy end - end diff --git a/thys/Monad_Memo_DP/transform/Transform_Cmd.thy b/thys/Monad_Memo_DP/transform/Transform_Cmd.thy --- a/thys/Monad_Memo_DP/transform/Transform_Cmd.thy +++ b/thys/Monad_Memo_DP/transform/Transform_Cmd.thy @@ -1,75 +1,76 @@ subsection \Tool Setup\ theory Transform_Cmd imports "../Pure_Monad" "../state_monad/DP_CRelVS" "../heap_monad/DP_CRelVH" keywords "memoize_fun" :: thy_decl and "monadifies" :: thy_decl and "memoize_correct" :: thy_goal and "with_memory" :: quasi_command and "default_proof" :: quasi_command begin ML_file \../transform/Transform_Misc.ML\ ML_file \../transform/Transform_Const.ML\ ML_file \../transform/Transform_Data.ML\ ML_file \../transform/Transform_Tactic.ML\ ML_file \../transform/Transform_Term.ML\ ML_file \../transform/Transform.ML\ ML_file \../transform/Transform_Parser.ML\ ML \ val _ = Outer_Syntax.local_theory @{command_keyword memoize_fun} "whatever" (Transform_Parser.dp_fun_part1_parser >> Transform_DP.dp_fun_part1_cmd) val _ = Outer_Syntax.local_theory @{command_keyword monadifies} "whatever" (Transform_Parser.dp_fun_part2_parser >> Transform_DP.dp_fun_part2_cmd) -\ -ML \ val _ = Outer_Syntax.local_theory_to_proof @{command_keyword memoize_correct} "whatever" (Scan.succeed Transform_DP.dp_correct_cmd) \ method_setup memoize_prover = \ -Scan.succeed (fn ctxt => (SIMPLE_METHOD' ( - Transform_Data.get_last_cmd_info ctxt - |> Transform_Tactic.solve_consistentDP_tac ctxt))) + Scan.succeed (fn ctxt => SIMPLE_METHOD' ( + Transform_Data.get_last_cmd_info ctxt + |> Transform_Tactic.solve_consistentDP_tac ctxt)) \ method_setup memoize_prover_init = \ -Scan.succeed (fn ctxt => (SIMPLE_METHOD' ( - Transform_Data.get_last_cmd_info ctxt - |> Transform_Tactic.prepare_consistentDP_tac ctxt))) + Scan.succeed (fn ctxt => SIMPLE_METHOD' ( + Transform_Data.get_last_cmd_info ctxt + |> Transform_Tactic.prepare_consistentDP_tac ctxt)) \ method_setup memoize_prover_case_init = \ -Scan.succeed (fn ctxt => (SIMPLE_METHOD' ( - Transform_Data.get_last_cmd_info ctxt - |> Transform_Tactic.prepare_case_tac ctxt))) + Scan.succeed (fn ctxt => SIMPLE_METHOD' ( + Transform_Data.get_last_cmd_info ctxt + |> Transform_Tactic.prepare_case_tac ctxt)) \ method_setup memoize_prover_match_step = \ -Scan.succeed (fn ctxt => (SIMPLE_METHOD' ( - Transform_Data.get_last_cmd_info ctxt - |> Transform_Tactic.step_tac ctxt))) + Scan.succeed (fn ctxt => SIMPLE_METHOD' ( + Transform_Data.get_last_cmd_info ctxt + |> Transform_Tactic.step_tac ctxt)) \ method_setup memoize_unfold_defs = \ -Scan.option (Scan.lift (Args.parens Args.name) -- Args.term) >> (fn tm_opt => fn ctxt => (SIMPLE_METHOD' ( - Transform_Data.get_or_last_cmd_info ctxt tm_opt - |> Transform_Tactic.dp_unfold_defs_tac ctxt))) + Scan.option (Scan.lift (Args.parens Args.name) -- Args.term) >> + (fn tm_opt => fn ctxt => SIMPLE_METHOD' + (Transform_Data.get_or_last_cmd_info ctxt tm_opt + |> Transform_Tactic.dp_unfold_defs_tac ctxt)) \ method_setup memoize_combinator_init = \ -Scan.option (Scan.lift (Args.parens Args.name) -- Args.term) >> (fn tm_opt => fn ctxt => (SIMPLE_METHOD' ( - Transform_Data.get_or_last_cmd_info ctxt tm_opt - |> Transform_Tactic.prepare_combinator_tac ctxt))) + Scan.option (Scan.lift (Args.parens Args.name) -- Args.term) >> + (fn tm_opt => fn ctxt => SIMPLE_METHOD' + (Transform_Data.get_or_last_cmd_info ctxt tm_opt + |> Transform_Tactic.prepare_combinator_tac ctxt)) \ -end (* theory *) + +end diff --git a/thys/Monad_Memo_DP/transform/Transform_Const.ML b/thys/Monad_Memo_DP/transform/Transform_Const.ML --- a/thys/Monad_Memo_DP/transform/Transform_Const.ML +++ b/thys/Monad_Memo_DP/transform/Transform_Const.ML @@ -1,98 +1,112 @@ -structure Transform_Const = struct +signature TRANSFORM_CONST = +sig + type MONAD_CONSTS = { + monad_name: string, + mk_stateT: typ -> typ, + return: term -> term, + app: (term * term) -> term, + if_termN: string, + checkmemVN: string, + rewrite_app_beta_conv: conv + } + val get_monad_const: string -> MONAD_CONSTS +end + +structure Transform_Const : TRANSFORM_CONST = +struct val pureappN = @{const_name Pure_Monad.App} fun pureapp tm = Const (pureappN, dummyT) $ tm type MONAD_CONSTS = { monad_name: string, mk_stateT: typ -> typ, return: term -> term, app: (term * term) -> term, if_termN: string, checkmemVN: string, rewrite_app_beta_conv: conv } - + val state_monad: MONAD_CONSTS = let val memT = TFree ("MemoryType", @{sort type}) val memT = dummyT - + fun mk_stateT tp = Type (@{type_name State_Monad.state}, [memT, tp]) - + val returnN = @{const_name State_Monad.return} fun return tm = Const (returnN, dummyT --> mk_stateT dummyT) $ tm - + val appN = @{const_name State_Monad_Ext.fun_app_lifted} fun app (tm0, tm1) = Const (appN, dummyT) $ tm0 $ tm1 - + fun checkmem'C ctxt = Transform_Misc.get_const_pat ctxt "checkmem'" fun checkmem' ctxt param body = checkmem'C ctxt $ param $ body - + val checkmemVN = "checkmem" val checkmemC = @{const_name "state_mem_defs.checkmem"} fun rewrite_app_beta_conv ctm = case Thm.term_of ctm of Const (@{const_name State_Monad_Ext.fun_app_lifted}, _) $ (Const (@{const_name State_Monad.return}, _) $ Abs _) $ (Const (@{const_name State_Monad.return}, _) $ _) => Conv.rewr_conv @{thm State_Monad_Ext.return_app_return_meta} ctm | _ => Conv.no_conv ctm in { monad_name = "state", mk_stateT = mk_stateT, return = return, app = app, if_termN = @{const_name State_Monad_Ext.if\<^sub>T}, checkmemVN = checkmemVN, rewrite_app_beta_conv = rewrite_app_beta_conv } end val heap_monad: MONAD_CONSTS = let fun mk_stateT tp = Type (@{type_name Heap_Monad.Heap}, [tp]) - + val returnN = @{const_name Heap_Monad.return} fun return tm = Const (returnN, dummyT --> mk_stateT dummyT) $ tm - + val appN = @{const_name Heap_Monad_Ext.fun_app_lifted} fun app (tm0, tm1) = Const (appN, dummyT) $ tm0 $ tm1 - + fun checkmem'C ctxt = Transform_Misc.get_const_pat ctxt "checkmem'" fun checkmem' ctxt param body = checkmem'C ctxt $ param $ body - + val checkmemVN = "checkmem" val checkmemC = @{const_name "heap_mem_defs.checkmem"} fun rewrite_app_beta_conv ctm = case Thm.term_of ctm of Const (@{const_name Heap_Monad_Ext.fun_app_lifted}, _) $ (Const (@{const_name Heap_Monad.return}, _) $ Abs _) $ (Const (@{const_name Heap_Monad.return}, _) $ _) => Conv.rewr_conv @{thm Heap_Monad_Ext.return_app_return_meta} ctm | _ => Conv.no_conv ctm in { monad_name = "heap", mk_stateT = mk_stateT, return = return, app = app, if_termN = @{const_name Heap_Monad_Ext.if\<^sub>T}, checkmemVN = checkmemVN, rewrite_app_beta_conv = rewrite_app_beta_conv } end val monad_consts_dict = [ ("state", state_monad), ("heap", heap_monad) ] fun get_monad_const name = case AList.lookup op= monad_consts_dict name of SOME consts => consts | NONE => error("unrecognized monad: " ^ name ^ " , choices: " ^ commas (map fst monad_consts_dict)); end - diff --git a/thys/Monad_Memo_DP/transform/Transform_Data.ML b/thys/Monad_Memo_DP/transform/Transform_Data.ML --- a/thys/Monad_Memo_DP/transform/Transform_Data.ML +++ b/thys/Monad_Memo_DP/transform/Transform_Data.ML @@ -1,142 +1,167 @@ -structure Transform_Data = struct +signature TRANSFORM_DATA = +sig + type dp_info = { + old_head: term, + new_head': term, + new_headT: term, + + old_defs: thm list, + new_defT: thm, + new_def': thm list + } + type cmd_info = { + scope: binding, + head: term, + locale: string option, + dp_info: dp_info option + } + val get_dp_info: string -> Proof.context -> term -> dp_info option + val get_last_cmd_info: Proof.context -> cmd_info + val commit_dp_info: string -> dp_info -> local_theory -> local_theory + val add_tmp_cmd_info: binding * term * string option -> local_theory -> local_theory + val get_or_last_cmd_info: Proof.context -> (string * term) option -> cmd_info +end + +structure Transform_Data : TRANSFORM_DATA = +struct type dp_info = { old_head: term, new_head': term, new_headT: term, old_defs: thm list, new_defT: thm, new_def': thm list } type cmd_info = { scope: binding, head: term, locale: string option, dp_info: dp_info option } fun map_cmd_info f0 f1 f2 f3 {scope, head, locale, dp_info} = {scope = f0 scope, head = f1 head, locale = f2 locale, dp_info = f3 dp_info} fun map_cmd_dp_info f = map_cmd_info I I I f structure Data = Generic_Data ( type T = { monadified_terms: (string * cmd_info Item_Net.T) list, last_cmd_info: cmd_info option } val empty = { monadified_terms = ["state", "heap"] ~~ replicate 2 (Item_Net.init (op aconv o apply2 #head) (single o #head)), last_cmd_info = NONE } val extend = I fun merge ( {monadified_terms = m0, ...}, {monadified_terms = m1, ...} ) = let val keys0 = map fst m0 val keys1 = map fst m1 val _ = @{assert} (keys0 = keys1) val vals = map Item_Net.merge (map snd m0 ~~ map snd m1) val ms = keys0 ~~ vals in {monadified_terms = ms, last_cmd_info = NONE} end ) fun transform_dp_info phi {old_head, new_head', new_headT, old_defs, new_defT, new_def'} = { old_head = Morphism.term phi old_head, new_head' = Morphism.term phi new_head', new_headT = Morphism.term phi new_headT, - + old_defs = Morphism.fact phi old_defs, new_def' = Morphism.fact phi new_def', new_defT = Morphism.thm phi new_defT } -fun get_monadified_terms_generic monad_name ctxt = - Data.get ctxt +fun get_monadified_terms_generic monad_name context = + Data.get context |> #monadified_terms |> (fn l => AList.lookup op= l monad_name) |> the -fun get_monadified_terms monad_name lthy = - get_monadified_terms_generic monad_name (Context.Proof lthy) +fun get_monadified_terms monad_name ctxt = + get_monadified_terms_generic monad_name (Context.Proof ctxt) fun map_data f0 f1 = Data.map (fn {monadified_terms, last_cmd_info} => {monadified_terms = f0 monadified_terms, last_cmd_info = f1 last_cmd_info}) fun map_monadified_terms f = map_data f I fun map_last_cmd_info f = map_data I f -fun put_monadified_terms_generic monad_name new_terms ctxt = - ctxt |> map_monadified_terms (AList.update op= (monad_name, new_terms)) +fun put_monadified_terms_generic monad_name new_terms context = + context |> map_monadified_terms (AList.update op= (monad_name, new_terms)) -fun map_monadified_terms_generic monad_name f ctxt = - ctxt |> map_monadified_terms (AList.map_entry op= monad_name f) +fun map_monadified_terms_generic monad_name f context = + context |> map_monadified_terms (AList.map_entry op= monad_name f) -fun put_last_cmd_info cmd_info_opt ctxt = - map_last_cmd_info (K cmd_info_opt) ctxt +fun put_last_cmd_info cmd_info_opt context = + map_last_cmd_info (K cmd_info_opt) context -fun get_cmd_info monad_name lthy tm = - get_monadified_terms monad_name lthy +fun get_cmd_info monad_name ctxt tm = + get_monadified_terms monad_name ctxt |> (fn net => Item_Net.retrieve net tm) -fun get_dp_info monad_name lthy tm = - get_cmd_info monad_name lthy tm +fun get_dp_info monad_name ctxt tm = + get_cmd_info monad_name ctxt tm |> (fn result => case result of {dp_info = SOME dp_info', ...} :: _ => SOME dp_info' | _ => NONE) -fun get_last_cmd_info_generic ctxt = - Data.get ctxt +fun get_last_cmd_info_generic context = + Data.get context |> #last_cmd_info |> the -fun get_last_cmd_info lthy = - get_last_cmd_info_generic (Context.Proof lthy) +fun get_last_cmd_info ctxt = + get_last_cmd_info_generic (Context.Proof ctxt) fun commit_dp_info monad_name dp_info = Local_Theory.declaration {pervasive = false, syntax = false} - (fn phi => fn ctxt => + (fn phi => fn context => let - val old_cmd_info = get_last_cmd_info_generic ctxt + val old_cmd_info = get_last_cmd_info_generic context val new_dp_info = transform_dp_info phi dp_info val new_cmd_info = old_cmd_info |> map_cmd_dp_info (K (SOME new_dp_info)) in - ctxt + context |> map_monadified_terms_generic monad_name (Item_Net.update new_cmd_info) |> put_last_cmd_info (SOME new_cmd_info) end) fun add_tmp_cmd_info (scope, head, locale_opt) = Local_Theory.declaration {pervasive = false, syntax = false} - (fn phi => fn ctxt => + (fn phi => fn context => let val new_cmd_info = { scope = Morphism.binding phi scope, head = Morphism.term phi head, locale = locale_opt, dp_info = NONE } in - ctxt |> put_last_cmd_info (SOME new_cmd_info) + context |> put_last_cmd_info (SOME new_cmd_info) end ) -fun get_or_last_cmd_info lthy monad_name_tm_opt = +fun get_or_last_cmd_info ctxt monad_name_tm_opt = case monad_name_tm_opt of - NONE => get_last_cmd_info lthy - | SOME (monad_name, tm) => get_cmd_info monad_name lthy tm |> the_single + NONE => get_last_cmd_info ctxt + | SOME (monad_name, tm) => get_cmd_info monad_name ctxt tm |> the_single end diff --git a/thys/Monad_Memo_DP/transform/Transform_Misc.ML b/thys/Monad_Memo_DP/transform/Transform_Misc.ML --- a/thys/Monad_Memo_DP/transform/Transform_Misc.ML +++ b/thys/Monad_Memo_DP/transform/Transform_Misc.ML @@ -1,73 +1,86 @@ -structure Transform_Misc = struct +signature TRANSFORM_MISC = +sig + val get_const_pat: Proof.context -> string -> term + val totality_of: Function.info -> thm + val rel_of: Function.info -> Proof.context -> thm + val the_element: int list -> int + val add_function: binding -> term list -> local_theory -> Function.info * local_theory + val behead: term -> term -> term * term list + val term_name: term -> string + val locale_term: Proof.context -> string -> string -> term + val locale_thms: Proof.context -> string -> string -> thm list + val uncurry: term -> term +end + +structure Transform_Misc : TRANSFORM_MISC = +struct fun import_function_info term_opt ctxt = case term_opt of SOME tm => (case Function_Common.import_function_data tm ctxt of SOME info => info | NONE => raise TERM("not a function", [tm])) | NONE => (case Function_Common.import_last_function ctxt of SOME info => info | NONE => error "no function defined yet") fun get_const_pat ctxt tm_pat = let val (Const (name, _)) = Proof_Context.read_const {proper=false, strict=false} ctxt tm_pat in Const (name, dummyT) end fun head_of (func_info: Function.info) = #fs func_info |> the_single fun bind_of (func_info: Function.info) = #fnames func_info |> the_single fun totality_of (func_info: Function.info) = func_info |> #totality |> the; fun rel_of (func_info: Function.info) ctxt = Inductive.the_inductive ctxt (#R func_info) |> snd |> #eqs |> the_single; fun the_element l = if tl l |> find_first (not o equal (hd l)) |> is_none then hd l else (@{print} l; error "inconsistent n_args") fun add_function bind defs = let val fixes = [(bind, NONE, NoSyn)]; val specs = map (fn def => (((Binding.empty, []), def), [], [])) defs - val pat_completeness_auto = fn ctxt => - Pat_Completeness.pat_completeness_tac ctxt 1 - THEN auto_tac ctxt in - Function.add_function fixes specs Function_Fun.fun_config pat_completeness_auto + Function.add_function fixes specs Function_Fun.fun_config + (fn ctxt => Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt) end fun behead head tm = let val head_nargs = strip_comb head |> snd |> length val (tm_head, tm_args) = strip_comb tm val (tm_args0, tm_args1) = chop head_nargs tm_args val tm_head' = list_comb (tm_head, tm_args0) val _ = if Term.aconv_untyped (head, tm_head') then () else raise TERM("head does not match", [head, tm_head']) in (tm_head', tm_args1) end fun term_name tm = if is_Free tm orelse is_Const tm then Term.term_name tm else raise TERM("not an atom, explicit name required", [tm]) - fun locale_term lthy locale_name term_name = - Syntax.read_term lthy (Long_Name.qualify locale_name term_name) + fun locale_term ctxt locale_name term_name = + Syntax.read_term ctxt (Long_Name.qualify locale_name term_name) - fun locale_thms lthy locale_name thms_name = - Proof_Context.get_thms lthy (Long_Name.qualify locale_name thms_name) + fun locale_thms ctxt locale_name thms_name = + Proof_Context.get_thms ctxt (Long_Name.qualify locale_name thms_name) fun uncurry tm = let val arg_typs = fastype_of tm |> binder_types val arg_names = Name.invent_list [] "a" (length arg_typs) val args = map Free (arg_names ~~ arg_typs) val args_tuple = HOLogic.mk_tuple args val tm' = list_comb (tm, args) |> HOLogic.tupled_lambda args_tuple in tm' end end diff --git a/thys/Monad_Memo_DP/transform/Transform_Parser.ML b/thys/Monad_Memo_DP/transform/Transform_Parser.ML --- a/thys/Monad_Memo_DP/transform/Transform_Parser.ML +++ b/thys/Monad_Memo_DP/transform/Transform_Parser.ML @@ -1,44 +1,52 @@ -structure Transform_Parser = struct +signature TRANSFORM_PARSER = +sig + val dp_fun_part1_parser: ((binding * string) * ((bool * (string * Position.T)) * + (string * string) list) option) parser + val dp_fun_part2_parser: (string * (Facts.ref * Token.src list) list) parser +end + +structure Transform_Parser : TRANSFORM_PARSER = +struct val dp_fun_parser = Parse.binding (* name of instantiation and monadified term *) (* fun dp_fun binding = Transform_Data.update_last_binding binding *) val memoizes_parser = Parse.name_position (* name of locale, e.g. dp_consistency_rbt *) val monadifies_parser = Parse.term (* term to be monadified *) -- Scan.option ( @{keyword "("} |-- Parse.thms1 --| (* optional definitions, ".simps" as default *) @{keyword ")"}) val dp_monadify_cmd_parser = Scan.optional (Parse.binding --| Parse.$$$ ":") Binding.empty (* optional scope *) -- Parse.term (* term to be monadified *) -- Scan.option (@{keyword "("} |-- (* optional definitions, ".simps" as default *) Parse.thms1 --| @{keyword ")"}) -- Scan.option (@{keyword with_memory} |-- Parse.name_position) (* e.g. dp_consistency_rbt *) val instance = (Parse.where_ |-- Parse.and_list1 (Parse.name -- (Parse.$$$ "=" |-- Parse.term)) || Scan.succeed []) val dp_fun_part1_parser = (Parse.binding --| Parse.$$$ ":") (* scope, e.g., bf\<^sub>T *) -- Parse.term (* term to be monadified, e.g., bf *) -- Scan.option (@{keyword with_memory} |-- Parse.opt_keyword "default_proof" -- Parse.name_position -- instance ) (* e.g. dp_consistency_rbt *) val dp_fun_part2_parser = (* monadifies *) (@{keyword "("} |-- Parse.name --| @{keyword ")"}) -- Parse.thms1 end diff --git a/thys/Monad_Memo_DP/transform/Transform_Tactic.ML b/thys/Monad_Memo_DP/transform/Transform_Tactic.ML --- a/thys/Monad_Memo_DP/transform/Transform_Tactic.ML +++ b/thys/Monad_Memo_DP/transform/Transform_Tactic.ML @@ -1,141 +1,157 @@ -structure Transform_Tactic = struct +signature TRANSFORM_TACTIC = +sig + val my_print_tac: string -> tactic + val totality_resolve_tac: thm -> thm -> thm -> Proof.context -> tactic + val totality_replay_tac: Function.info -> Function.info -> Proof.context -> tactic + val solve_relator_tac: Proof.context -> int -> tactic + val transfer_raw_tac: Proof.context -> int -> tactic + val step_tac: Proof.context -> Transform_Data.cmd_info -> int -> tactic + val prepare_case_tac: Proof.context -> Transform_Data.cmd_info -> int -> tactic + val prepare_consistentDP_tac: Proof.context -> Transform_Data.cmd_info -> int -> tactic + val solve_consistentDP_tac: Proof.context -> Transform_Data.cmd_info -> int -> tactic + val prepare_combinator_tac: Proof.context -> Transform_Data.cmd_info -> int -> tactic + val dp_unfold_defs_tac: Proof.context -> Transform_Data.cmd_info -> int -> tactic +end + +structure Transform_Tactic : TRANSFORM_TACTIC = +struct fun my_print_tac msg st = (tracing msg; all_tac st) - + fun totality_resolve_tac totality0 def0 def1 ctxt = let val totality0_unfolded = totality0 |> Local_Defs.unfold ctxt [def0] val totality1 = totality0_unfolded |> Local_Defs.fold ctxt [def1] in if Thm.full_prop_of totality0_unfolded aconv Thm.full_prop_of totality1 then let val msg = Pretty.string_of (Pretty.block [ Pretty.str "Failed to transform totality from", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt (Thm.full_prop_of def0)), Pretty.brk 1, Pretty.str "to", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt (Thm.full_prop_of def1)), Pretty.brk 1]) in (*print_tac ctxt msg THEN*) no_tac end else HEADGOAL (resolve_tac ctxt [totality1]) end - + fun totality_blast_tac totality0 def0 def1 ctxt = HEADGOAL ( (resolve_tac ctxt [totality0 RS @{thm rev_iffD1}]) THEN' (resolve_tac ctxt [@{thm arg_cong[where f=HOL.All]}]) THEN' SELECT_GOAL (unfold_tac ctxt (map (Local_Defs.abs_def_rule ctxt) [def0, def1])) THEN' (resolve_tac ctxt [@{thm arg_cong[where f=Wellfounded.accp]}]) THEN' (Blast.depth_tac ctxt 2) ) - + fun totality_replay_tac old_info new_info ctxt = let val totality0 = Transform_Misc.totality_of old_info val def0 = Transform_Misc.rel_of old_info ctxt val def1 = Transform_Misc.rel_of new_info ctxt fun my_print_tac msg st = (tracing msg; all_tac st) in no_tac ORELSE (totality_resolve_tac totality0 def0 def1 ctxt THEN my_print_tac "termination by replaying") ORELSE (totality_blast_tac totality0 def0 def1 ctxt THEN my_print_tac "termination by blast") end fun dp_intro_tac ctxt (cmd_info: Transform_Data.cmd_info) = let val scope_name = Binding.name_of (#scope cmd_info) val consistentDP_rule = Transform_Misc.locale_thms ctxt scope_name "consistentDP_intro" in resolve_tac ctxt consistentDP_rule end fun expand_relator_tac ctxt = SELECT_GOAL (Local_Defs.fold_tac ctxt (Transfer.get_relator_eq ctxt)) fun solve_relator_tac ctxt = SOLVED' (Transfer.eq_tac ctxt) fun split_params_tac ctxt = clarify_tac ctxt fun dp_induct_tac ctxt (cmd_info: Transform_Data.cmd_info) = let val dpT' = cmd_info |> #dp_info |> the |> #new_head' val dpT'_info = Function.get_info ctxt dpT' val induct_rule = dpT'_info |> #inducts |> the in resolve_tac ctxt induct_rule end fun dp_unfold_def_tac ctxt (cmd_info: Transform_Data.cmd_info) sel = cmd_info |> #dp_info |> the |> sel |> map (Local_Defs.meta_rewrite_rule ctxt) - |> Conv.rewrs_conv + |> Conv.rewrs_conv |> Conv.try_conv |> Conv.binop_conv - |> HOLogic.Trueprop_conv + |> HOLogic.Trueprop_conv |> Conv.concl_conv ~1 |> (fn cv => Conv.params_conv ~1 (K cv) ctxt) |> CONVERSION (* |> EqSubst.eqsubst_tac ctxt [0] : may rewrite locale parameters in certain situations *) fun dp_match_rule_tac ctxt (cmd_info: Transform_Data.cmd_info) = let val scope_name = Binding.name_of (#scope cmd_info) val dp_match_rules = Transform_Misc.locale_thms ctxt scope_name "dp_match_rule" in resolve_tac ctxt dp_match_rules end fun checkmem_tac ctxt (cmd_info: Transform_Data.cmd_info) = let val scope_name = Binding.name_of (#scope cmd_info) val dp_match_rules = Transform_Misc.locale_thms ctxt scope_name "crel_vs_checkmem_tupled" in resolve_tac ctxt dp_match_rules THEN' SOLVED' (clarify_tac ctxt) THEN' Transfer.eq_tac ctxt end fun solve_IH_tac ctxt = Method.assm_tac ctxt fun transfer_raw_tac ctxt = resolve_tac ctxt (Transfer.get_transfer_raw ctxt) fun step_tac ctxt (cmd_info: Transform_Data.cmd_info) = solve_IH_tac ctxt ORELSE' solve_relator_tac ctxt ORELSE' dp_match_rule_tac ctxt cmd_info ORELSE' transfer_raw_tac ctxt fun prepare_case_tac ctxt (cmd_info: Transform_Data.cmd_info) = dp_unfold_def_tac ctxt cmd_info #new_def' THEN' checkmem_tac ctxt cmd_info THEN' dp_unfold_def_tac ctxt cmd_info #old_defs fun solve_case_tac ctxt (cmd_info: Transform_Data.cmd_info) = prepare_case_tac ctxt cmd_info THEN' REPEAT_ALL_NEW (step_tac ctxt cmd_info) fun prepare_consistentDP_tac ctxt (cmd_info: Transform_Data.cmd_info) = dp_intro_tac ctxt cmd_info THEN' expand_relator_tac ctxt THEN' split_params_tac ctxt THEN' dp_induct_tac ctxt cmd_info fun solve_consistentDP_tac ctxt (cmd_info: Transform_Data.cmd_info) = prepare_consistentDP_tac ctxt cmd_info THEN_ALL_NEW SOLVED' (solve_case_tac ctxt cmd_info) fun prepare_combinator_tac ctxt (cmd_info: Transform_Data.cmd_info) = EqSubst.eqsubst_tac ctxt [0] @{thms Rel_def[symmetric]} THEN' dp_unfold_def_tac ctxt cmd_info (single o #new_defT) THEN' REPEAT_ALL_NEW (resolve_tac ctxt (@{thm Rel_abs} :: Transform_Misc.locale_thms ctxt "local" "crel_vs_return_ext")) THEN' (SELECT_GOAL (unfold_tac ctxt @{thms Rel_def})) fun dp_unfold_defs_tac ctxt (cmd_info: Transform_Data.cmd_info) = dp_unfold_def_tac ctxt cmd_info #new_def' THEN' dp_unfold_def_tac ctxt cmd_info #old_defs end diff --git a/thys/Monad_Memo_DP/transform/Transform_Term.ML b/thys/Monad_Memo_DP/transform/Transform_Term.ML --- a/thys/Monad_Memo_DP/transform/Transform_Term.ML +++ b/thys/Monad_Memo_DP/transform/Transform_Term.ML @@ -1,340 +1,345 @@ -fun list_conv (head_conv, arg_convs) lthy = - Library.foldl (uncurry Conv.combination_conv) (head_conv lthy, map (fn conv => conv lthy) arg_convs) +signature TRANSFORM_TERM = +sig + val repeat_sweep_conv: (Proof.context -> conv) -> Proof.context -> conv + val rewrite_pureapp_beta_conv: conv + val wrap_head: Transform_Const.MONAD_CONSTS -> term -> int -> term + val lift_equation: Transform_Const.MONAD_CONSTS -> Proof.context -> + term * term -> term option -> (Proof.context -> conv) * term * int +end + +structure Transform_Term : TRANSFORM_TERM = +struct + +fun list_conv (head_conv, arg_convs) ctxt = + Library.foldl (uncurry Conv.combination_conv) + (head_conv ctxt, map (fn conv => conv ctxt) arg_convs) fun eta_conv1 ctxt = (Conv.abs_conv (K Conv.all_conv) ctxt) else_conv (Thm.eta_long_conversion then_conv Conv.abs_conv (K Thm.eta_conversion) ctxt) fun eta_conv_n n = funpow n (fn conv => fn ctxt => eta_conv1 ctxt then_conv Conv.abs_conv (fn (_, ctxt) => conv ctxt) ctxt) (K Conv.all_conv) fun conv_changed conv ctm = let val eq = conv ctm in if Thm.is_reflexive eq then Conv.no_conv ctm else eq end fun repeat_sweep_conv conv = Conv.repeat_conv o conv_changed o Conv.top_sweep_conv conv val app_mark_conv = Conv.rewr_conv @{thm App_def[symmetric]} val app_unmark_conv = Conv.rewr_conv @{thm App_def} val wrap_mark_conv = Conv.rewr_conv @{thm Wrap_def[symmetric]} -structure Transform_Term = struct - fun eta_expand tm = let val n_args = Integer.min 1 (length (binder_types (fastype_of tm))) val (args, body) = Term.strip_abs_eta n_args tm in Library.foldr (uncurry Term.absfree) (args, body) end fun is_ctr_sugar ctxt tp_name = is_some (Ctr_Sugar.ctr_sugar_of ctxt tp_name) fun type_nargs tp = tp |> strip_type |> fst |> length fun term_nargs tm = type_nargs (fastype_of tm) -fun - lift_type (monad_consts: Transform_Const.MONAD_CONSTS) ctxt tp = #mk_stateT monad_consts (lift_type' monad_consts ctxt tp) -and - lift_type' monad_consts ctxt (tp as Type (@{type_name fun}, _)) - = lift_type' monad_consts ctxt (domain_type tp) --> lift_type monad_consts ctxt (range_type tp) -| lift_type' monad_consts ctxt (tp as Type (tp_name, tp_args)) - = if is_ctr_sugar ctxt tp_name then Type (tp_name, map (lift_type' monad_consts ctxt) tp_args) +fun lift_type (monad_consts: Transform_Const.MONAD_CONSTS) ctxt tp = + #mk_stateT monad_consts (lift_type' monad_consts ctxt tp) +and lift_type' monad_consts ctxt (tp as Type (@{type_name fun}, _)) = + lift_type' monad_consts ctxt (domain_type tp) --> lift_type monad_consts ctxt (range_type tp) + | lift_type' monad_consts ctxt (tp as Type (tp_name, tp_args)) = + if is_ctr_sugar ctxt tp_name then Type (tp_name, map (lift_type' monad_consts ctxt) tp_args) else if null tp_args then tp (* int, nat, \ *) else raise TYPE("not a ctr_sugar", [tp], []) -| lift_type' _ _ tp = tp + | lift_type' _ _ tp = tp fun is_atom_type monad_consts ctxt tp = tp = lift_type' monad_consts ctxt tp fun is_1st_type monad_consts ctxt tp = body_type tp :: binder_types tp |> forall (is_atom_type monad_consts ctxt) fun orig_atom ctxt atom_name = Proof_Context.read_term_pattern ctxt atom_name fun is_1st_term monad_consts ctxt tm = is_1st_type monad_consts ctxt (fastype_of tm) fun is_1st_atom monad_consts ctxt atom_name = is_1st_term monad_consts ctxt (orig_atom ctxt atom_name) fun wrap_1st_term monad_consts ctxt tm n_args_opt inner_wrap = let val n_args = the_default (term_nargs tm) n_args_opt val (vars_name_typ, body) = Term.strip_abs_eta n_args tm fun wrap (name_typ, (conv, tm)) = ( eta_conv1 ctxt then_conv Conv.abs_conv (K conv) ctxt then_conv wrap_mark_conv, #return monad_consts (Term.absfree name_typ tm) ) val (conv, result) = Library.foldr wrap (vars_name_typ, ( - if inner_wrap then (wrap_mark_conv, #return monad_consts body) else (Conv.all_conv, body) + if inner_wrap then (wrap_mark_conv, #return monad_consts body) else (Conv.all_conv, body) )) - in (K conv, result) end fun lift_1st_atom monad_consts ctxt atom (name, tp) = let val (arg_typs, body_typ) = strip_type tp val n_args = term_nargs (orig_atom ctxt name) val (arg_typs, body_arg_typs) = chop n_args arg_typs val arg_typs' = map (lift_type' monad_consts ctxt) arg_typs val body_typ' = lift_type' monad_consts ctxt (body_arg_typs ---> body_typ) val tm' = atom (name, arg_typs' ---> body_typ') (* " *) in wrap_1st_term monad_consts ctxt tm' (SOME n_args) true end fun fixed_args head_n_args tm = let val (tm_head, tm_args) = strip_comb tm val n_tm_args = length tm_args in head_n_args tm_head |> Option.mapPartial (fn n_args => if n_tm_args > n_args then NONE else if n_tm_args < n_args then raise TERM("need " ^ string_of_int n_args ^ " args", [tm]) else SOME (tm_head, tm_args)) end fun lift_abs' monad_consts ctxt (name, typ) cont lift_dict body = let val free = Free (name, typ) val typ' = lift_type' monad_consts ctxt typ val freeT' = Free (name, typ') val freeT = #return monad_consts (freeT') val lift_dict' = if is_atom_type monad_consts ctxt typ then lift_dict else (free, (K wrap_mark_conv, freeT))::lift_dict val (conv_free, body_free) = (cont (lift_dict') body) val body' = lambda freeT' body_free - fun conv lthy = - eta_conv1 ctxt then_conv Conv.abs_conv (fn (_, lthy') => conv_free lthy') lthy - in - (conv, body') - end + fun conv ctxt' = eta_conv1 ctxt then_conv Conv.abs_conv (conv_free o #2) ctxt' + in (conv, body') end fun lift_arg monad_consts ctxt lift_dict tm = (* let val (conv, tm') = lift_term ctxt lift_dict (eta_expand tm) fun conv' ctxt = Conv.try_conv (eta_conv1 ctxt) then_conv (conv ctxt) in (conv', tm') end eta_expand AFTER lifting *) lift_term monad_consts ctxt lift_dict tm and lift_term monad_consts ctxt lift_dict tm = let val case_terms = Ctr_Sugar.ctr_sugars_of ctxt |> map #casex fun lookup_case_term tm = find_first (fn x => Term.aconv_untyped (x, tm)) case_terms val check_cont = lift_term monad_consts ctxt val check_cont_arg = lift_arg monad_consts ctxt fun check_const tm = case tm of Const (_, typ) => ( case Transform_Data.get_dp_info (#monad_name monad_consts) ctxt tm of SOME {new_headT=Const (name, _), ...} => SOME (K Conv.all_conv, Const (name, lift_type monad_consts ctxt typ)) | SOME {new_headT=tm', ...} => raise TERM("not a constant", [tm']) | NONE => NONE) | _ => NONE fun check_1st_atom tm = case tm of Const (name, typ) => if is_1st_atom monad_consts ctxt name then SOME (lift_1st_atom monad_consts ctxt Const (name, typ)) else NONE | Free (name, typ) => if is_1st_atom monad_consts ctxt name then SOME (lift_1st_atom monad_consts ctxt Free (name, typ)) else NONE | _ => (* if is_1st_term ctxt tm andalso exists_subterm (AList.defined (op aconv) lift_dict) tm then SOME (wrap_1st_term tm NONE) else *) NONE (* fun check_dict tm = (* TODO: map -> mapT, dummyT *) AList.lookup Term.aconv_untyped lift_dict tm |> Option.map (fn tm' => if is_Const tm then (@{assert} (is_Const tm'); map_types (K (lift_type ctxt (type_of tm))) tm') else tm') *) fun check_dict tm = AList.lookup Term.aconv_untyped lift_dict tm fun check_if tm = fixed_args (fn head => if Term.aconv_untyped (head, @{term If}) then SOME 3 else NONE) tm |> Option.map (fn (_, args) => let val (arg_convs, args') = map (check_cont lift_dict) args |> split_list val conv = list_conv (K Conv.all_conv, arg_convs) val tm' = list_comb (Const (#if_termN monad_consts, dummyT), args') in (conv, tm') end) fun check_abs tm = case tm of Abs (name, typ, body) => let val (name', body') = Term.dest_abs (name, typ, body) val (conv, tm') = lift_abs' monad_consts ctxt (name', typ) check_cont lift_dict body' - fun convT lthy = conv lthy then_conv wrap_mark_conv + fun convT ctxt' = conv ctxt' then_conv wrap_mark_conv val tmT = #return monad_consts tm' in SOME (convT, tmT) end | _ => NONE fun check_case tm = fixed_args (lookup_case_term #> Option.map (fn cs => term_nargs cs - 1)) tm |> Option.map (fn (head, args) => let val (case_name, case_type) = lookup_case_term head |> the |> dest_Const val ((clause_typs, _), _) = strip_type case_type |>> split_last - + val clase_nparams = clause_typs |> map type_nargs (* ('a\'b) \ ('a\'b) |> type_nargs = 1*) - + fun lift_clause n_param clause = let val (vars_name_typ, body) = Term.strip_abs_eta n_param clause val abs_lift_wraps = map (lift_abs' monad_consts ctxt) vars_name_typ val lift_wrap = Library.foldr (op o) (abs_lift_wraps, I) check_cont val (conv, clauseT) = lift_wrap lift_dict body in (conv, clauseT) end - + val head' = Const (case_name, dummyT) (* clauses are sufficient for type inference *) val (convs, clauses') = map2 lift_clause clase_nparams args |> split_list - fun conv lthy = list_conv (K Conv.all_conv, convs) lthy then_conv wrap_mark_conv + fun conv ctxt' = list_conv (K Conv.all_conv, convs) ctxt' then_conv wrap_mark_conv val tm' = #return monad_consts (list_comb (head', clauses')) in (conv, tm') end) fun check_app tm = case tm of f $ x => let val (f_conv, tmf) = check_cont lift_dict f val (x_conv, tmx) = check_cont_arg lift_dict x val tm' = #app monad_consts (tmf, tmx) - fun conv lthy = Conv.combination_conv (f_conv lthy then_conv app_mark_conv) (x_conv lthy) + fun conv ctxt' = Conv.combination_conv (f_conv ctxt' then_conv app_mark_conv) (x_conv ctxt') in SOME (conv, tm') end | _ => NONE fun check_pure tm = if tm |> exists_subterm (AList.defined (op aconv) lift_dict) orelse not (is_1st_term monad_consts ctxt tm) then NONE else SOME (wrap_1st_term monad_consts ctxt tm NONE true) fun choke tm = raise TERM("cannot process term", [tm]) val checks = [ check_pure, check_const, check_case, check_if, check_abs, check_app, check_dict, check_1st_atom, choke ] in get_first (fn check => check tm) checks |> the end fun rewrite_pureapp_beta_conv ctm = case Thm.term_of ctm of Const (@{const_name Pure_Monad.App}, _) $ (Const (@{const_name Pure_Monad.Wrap}, _) $ Abs _) $ (Const (@{const_name Pure_Monad.Wrap}, _) $ _) => Conv.rewr_conv @{thm Wrap_App_Wrap} ctm | _ => Conv.no_conv ctm fun monadify monad_consts ctxt tm = let val (_, tm) = lift_term monad_consts ctxt [] tm (*val tm = rewrite_return_app_return tm*) val tm = Syntax.check_term ctxt tm in tm end fun wrap_head (monad_consts: Transform_Const.MONAD_CONSTS) head n_args = Library.foldr (fn (typ, tm) => #return monad_consts (absdummy typ tm)) (replicate n_args dummyT, list_comb (head, rev (map_range Bound n_args))) fun lift_head monad_consts ctxt head n_args = let val dest_head = if is_Const head then dest_Const else dest_Free val (head_name, head_typ) = dest_head head val (arg_typs, body_typ) = strip_type head_typ val (arg_typs0, arg_typs1) = chop n_args arg_typs val arg_typs0' = map (lift_type' monad_consts ctxt) arg_typs0 val arg_typs1T = lift_type monad_consts ctxt (arg_typs1 ---> body_typ) val head_typ' = arg_typs0' ---> arg_typs1T val head' = Free (head_name, head_typ') val (head_conv, headT) = wrap_1st_term monad_consts ctxt head' (SOME n_args) false in (head', (head_conv, headT)) end fun lift_equation monad_consts ctxt (lhs, rhs) memoizer_opt = let val (head, args) = strip_comb lhs val n_args = length args val (head', (head_conv, headT)) = lift_head monad_consts ctxt head n_args val args' = args |> map (map_aterms (fn tm => tm |> map_types (if is_Const tm then K dummyT else lift_type' monad_consts ctxt))) val lhs' = list_comb (head', args') val frees = fold Term.add_frees args [] |> filter_out (is_atom_type monad_consts ctxt o snd) - + val lift_dict_args = frees |> map (fn (name, typ) => ( - Free (name, typ), + Free (name, typ), (K wrap_mark_conv, #return monad_consts (Free (name, lift_type' monad_consts ctxt typ))) )) val lift_dict = (head, (head_conv, headT)) :: lift_dict_args val (rhs_conv, rhsT) = lift_term monad_consts ctxt lift_dict rhs val rhsT_memoized = case memoizer_opt of SOME memoizer => memoizer $ HOLogic.mk_tuple args $ rhsT | NONE => rhsT val eqs' = (lhs', rhsT_memoized) |> HOLogic.mk_eq |> HOLogic.mk_Trueprop in (rhs_conv, eqs', n_args) end - end diff --git a/thys/Monad_Memo_DP/util/Ground_Function.ML b/thys/Monad_Memo_DP/util/Ground_Function.ML --- a/thys/Monad_Memo_DP/util/Ground_Function.ML +++ b/thys/Monad_Memo_DP/util/Ground_Function.ML @@ -1,39 +1,47 @@ (** Define a new ground constant from an existing function definition **) -structure Ground_Function = struct + +signature GROUND_FUNCTION = +sig + val mk_fun: bool -> thm list -> binding -> local_theory -> local_theory +end + +structure Ground_Function : GROUND_FUNCTION = +struct fun add_function bind defs = let val fixes = [(bind, NONE, NoSyn)]; val specs = map (fn def => (((Binding.empty, []), def), [], [])) defs - val pat_completeness_auto = fn ctxt => - Pat_Completeness.pat_completeness_tac ctxt 1 - THEN auto_tac ctxt in - Function.add_function fixes specs Function_Fun.fun_config pat_completeness_auto + Function.add_function fixes specs Function_Fun.fun_config + (fn ctxt => Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt) end fun dest_hol_eq_prop t = let val Const ("HOL.Trueprop", _) $ (Const ("HOL.eq", _) $ a $ b) = t in (a, b) end fun get_fun_head t = let val (t, _) = dest_hol_eq_prop t val t = Term.head_of t val Const (fun_name, fun_ty) = t in (fun_name, fun_ty) end -fun mk_fun termination simps binding ctxt = +fun mk_fun termination simps binding lthy = let val eqns = map Thm.concl_of simps - val (eqns, _) = Variable.import_terms true eqns ctxt + val (eqns, _) = Variable.import_terms true eqns lthy val (f_name, f_ty) = get_fun_head (hd eqns) val s = Binding.name_of binding val replacement = (Const (f_name, f_ty), Free (s, f_ty)) val eqns = map (subst_free [replacement]) eqns - val (_, ctxt) = add_function binding eqns ctxt - fun prove_termination lthy = - Function.prove_termination NONE (Function_Common.termination_prover_tac false lthy) lthy - in ctxt |> (if termination then snd o prove_termination else I) end + fun prove_termination lthy' = lthy' + |> Function.prove_termination NONE (Function_Common.termination_prover_tac false lthy') + in + lthy + |> add_function binding eqns |> #2 + |> termination ? (prove_termination #> #2) + end end