diff --git a/thys/Refine_Imperative_HOL/Sepref_Frame.thy b/thys/Refine_Imperative_HOL/Sepref_Frame.thy --- a/thys/Refine_Imperative_HOL/Sepref_Frame.thy +++ b/thys/Refine_Imperative_HOL/Sepref_Frame.thy @@ -1,541 +1,557 @@ section \Frame Inference\ theory Sepref_Frame imports Sepref_Basic Sepref_Constraints begin text \ In this theory, we provide a specific frame inference tactic for Sepref. The first tactic, \frame_tac\, is a standard frame inference tactic, based on the assumption that only @{const hn_ctxt}-assertions need to be matched. The second tactic, \merge_tac\, resolves entailments of the form \F1 \\<^sub>A F2 \\<^sub>t ?F\ that occur during translation of if and case statements. It synthesizes a new frame ?F, where refinements of variables with equal refinements in \F1\ and \F2\ are preserved, and the others are set to @{const hn_invalid}. \ definition mismatch_assn :: "('a \ 'c \ assn) \ ('a \ 'c \ assn) \ 'a \ 'c \ assn" where "mismatch_assn R1 R2 x y \ R1 x y \\<^sub>A R2 x y" abbreviation "hn_mismatch R1 R2 \ hn_ctxt (mismatch_assn R1 R2)" lemma recover_pure_aux: "CONSTRAINT is_pure R \ hn_invalid R x y \\<^sub>t hn_ctxt R x y" by (auto simp: is_pure_conv invalid_pure_recover hn_ctxt_def) lemma frame_thms: "P \\<^sub>t P" "P\\<^sub>tP' \ F\\<^sub>tF' \ F*P \\<^sub>t F'*P'" "hn_ctxt R x y \\<^sub>t hn_invalid R x y" "hn_ctxt R x y \\<^sub>t hn_ctxt (\_ _. true) x y" "CONSTRAINT is_pure R \ hn_invalid R x y \\<^sub>t hn_ctxt R x y" apply - applyS simp applyS (rule entt_star_mono; assumption) subgoal apply (simp add: hn_ctxt_def) apply (rule enttI) apply (rule ent_trans[OF invalidate[of R]]) by solve_entails applyS (sep_auto simp: hn_ctxt_def) applyS (erule recover_pure_aux) done named_theorems_rev sepref_frame_match_rules \Sepref: Additional frame rules\ text \Rules to discharge unmatched stuff\ (*lemma frame_rem_thms: "P \\<^sub>t P" "P \\<^sub>t emp" by sep_auto+ *) lemma frame_rem1: "P\\<^sub>tP" by simp lemma frame_rem2: "F \\<^sub>t F' \ F * hn_ctxt A x y \\<^sub>t F' * hn_ctxt A x y" apply (rule entt_star_mono) by auto lemma frame_rem3: "F \\<^sub>t F' \ F * hn_ctxt A x y \\<^sub>t F'" using frame_thms(2) by fastforce lemma frame_rem4: "P \\<^sub>t emp" by simp lemmas frame_rem_thms = frame_rem1 frame_rem2 frame_rem3 frame_rem4 named_theorems_rev sepref_frame_rem_rules \Sepref: Additional rules to resolve remainder of frame-pairing\ lemma ent_disj_star_mono: "\ A \\<^sub>A C \\<^sub>A E; B \\<^sub>A D \\<^sub>A F \ \ A*B \\<^sub>A C*D \\<^sub>A E*F" by (metis ent_disjI1 ent_disjI2 ent_disjE ent_star_mono) lemma entt_disj_star_mono: "\ A \\<^sub>A C \\<^sub>t E; B \\<^sub>A D \\<^sub>t F \ \ A*B \\<^sub>A C*D \\<^sub>t E*F" proof - assume a1: "A \\<^sub>A C \\<^sub>t E" assume "B \\<^sub>A D \\<^sub>t F" then have "A * B \\<^sub>A C * D \\<^sub>A true * E * (true * F)" using a1 by (simp add: ent_disj_star_mono enttD) then show ?thesis by (metis (no_types) assn_times_comm enttI merge_true_star_ctx star_aci(3)) qed lemma hn_merge1: (*"emp \\<^sub>A emp \\<^sub>A emp"*) "F \\<^sub>A F \\<^sub>t F" "\ hn_ctxt R1 x x' \\<^sub>A hn_ctxt R2 x x' \\<^sub>t hn_ctxt R x x'; Fl \\<^sub>A Fr \\<^sub>t F \ \ Fl * hn_ctxt R1 x x' \\<^sub>A Fr * hn_ctxt R2 x x' \\<^sub>t F * hn_ctxt R x x'" apply simp by (rule entt_disj_star_mono; simp) lemma hn_merge2: "hn_invalid R x x' \\<^sub>A hn_ctxt R x x' \\<^sub>t hn_invalid R x x'" "hn_ctxt R x x' \\<^sub>A hn_invalid R x x' \\<^sub>t hn_invalid R x x'" by (sep_auto eintros: invalidate ent_disjE intro!: ent_imp_entt simp: hn_ctxt_def)+ lemma invalid_assn_mono: "hn_ctxt A x y \\<^sub>t hn_ctxt B x y \ hn_invalid A x y \\<^sub>t hn_invalid B x y" by (clarsimp simp: invalid_assn_def entailst_def entails_def hn_ctxt_def) (force simp: mod_star_conv) lemma hn_merge3: (* Not used *) "\NO_MATCH (hn_invalid XX) R2; hn_ctxt R1 x x' \\<^sub>A hn_ctxt R2 x x' \\<^sub>t hn_ctxt Rm x x'\ \ hn_invalid R1 x x' \\<^sub>A hn_ctxt R2 x x' \\<^sub>t hn_invalid Rm x x'" "\NO_MATCH (hn_invalid XX) R1; hn_ctxt R1 x x' \\<^sub>A hn_ctxt R2 x x' \\<^sub>t hn_ctxt Rm x x'\ \ hn_ctxt R1 x x' \\<^sub>A hn_invalid R2 x x' \\<^sub>t hn_invalid Rm x x'" apply (meson entt_disjD1 entt_disjD2 entt_disjE entt_trans frame_thms(3) invalid_assn_mono) apply (meson entt_disjD1 entt_disjD2 entt_disjE entt_trans frame_thms(3) invalid_assn_mono) done lemmas merge_thms = hn_merge1 hn_merge2 named_theorems sepref_frame_merge_rules \Sepref: Additional merge rules\ lemma hn_merge_mismatch: "hn_ctxt R1 x x' \\<^sub>A hn_ctxt R2 x x' \\<^sub>t hn_mismatch R1 R2 x x'" by (sep_auto simp: hn_ctxt_def mismatch_assn_def) lemma is_merge: "P1\\<^sub>AP2\\<^sub>tP \ P1\\<^sub>AP2\\<^sub>tP" . lemma merge_mono: "\A\\<^sub>tA'; B\\<^sub>tB'; A'\\<^sub>AB' \\<^sub>t C\ \ A\\<^sub>AB \\<^sub>t C" by (meson entt_disjE entt_disjI1_direct entt_disjI2_direct entt_trans) text \Apply forward rule on left or right side of merge\ lemma gen_merge_cons1: "\A\\<^sub>tA'; A'\\<^sub>AB \\<^sub>t C\ \ A\\<^sub>AB \\<^sub>t C" by (meson merge_mono entt_refl) lemma gen_merge_cons2: "\B\\<^sub>tB'; A\\<^sub>AB' \\<^sub>t C\ \ A\\<^sub>AB \\<^sub>t C" by (meson merge_mono entt_refl) lemmas gen_merge_cons = gen_merge_cons1 gen_merge_cons2 text \These rules are applied to recover pure values that have been destroyed by rule application\ definition "RECOVER_PURE P Q \ P \\<^sub>t Q" lemma recover_pure: "RECOVER_PURE emp emp" "\RECOVER_PURE P2 Q2; RECOVER_PURE P1 Q1\ \ RECOVER_PURE (P1*P2) (Q1*Q2)" "CONSTRAINT is_pure R \ RECOVER_PURE (hn_invalid R x y) (hn_ctxt R x y)" "RECOVER_PURE (hn_ctxt R x y) (hn_ctxt R x y)" unfolding RECOVER_PURE_def subgoal by sep_auto subgoal by (drule (1) entt_star_mono) subgoal by (rule recover_pure_aux) subgoal by sep_auto done lemma recover_pure_triv: "RECOVER_PURE P P" unfolding RECOVER_PURE_def by sep_auto text \Weakening the postcondition by converting @{const invalid_assn} to @{term "\_ _. true"}\ definition "WEAKEN_HNR_POST \ \' \'' \ (\h. h\\) \ (\'' \\<^sub>t \')" lemma weaken_hnr_postI: assumes "WEAKEN_HNR_POST \ \'' \'" assumes "hn_refine \ c \' R a" shows "hn_refine \ c \'' R a" apply (rule hn_refine_preI) apply (rule hn_refine_cons_post) apply (rule assms) using assms(1) unfolding WEAKEN_HNR_POST_def by blast lemma weaken_hnr_post_triv: "WEAKEN_HNR_POST \ P P" unfolding WEAKEN_HNR_POST_def by sep_auto lemma weaken_hnr_post: "\WEAKEN_HNR_POST \ P P'; WEAKEN_HNR_POST \' Q Q'\ \ WEAKEN_HNR_POST (\*\') (P*Q) (P'*Q')" "WEAKEN_HNR_POST (hn_ctxt R x y) (hn_ctxt R x y) (hn_ctxt R x y)" "WEAKEN_HNR_POST (hn_ctxt R x y) (hn_invalid R x y) (hn_ctxt (\_ _. true) x y)" proof (goal_cases) case 1 thus ?case unfolding WEAKEN_HNR_POST_def apply clarsimp apply (rule entt_star_mono) by (auto simp: mod_star_conv) next case 2 thus ?case by (rule weaken_hnr_post_triv) next case 3 thus ?case unfolding WEAKEN_HNR_POST_def by (sep_auto simp: invalid_assn_def hn_ctxt_def) qed lemma reorder_enttI: assumes "A*true = C*true" assumes "B*true = D*true" shows "(A\\<^sub>tB) \ (C\\<^sub>tD)" apply (intro eq_reflection) unfolding entt_def_true by (simp add: assms) lemma merge_sat1: "(A\\<^sub>AA' \\<^sub>t Am) \ (A\\<^sub>AAm \\<^sub>t Am)" using entt_disjD1 entt_disjE by blast lemma merge_sat2: "(A\\<^sub>AA' \\<^sub>t Am) \ (Am\\<^sub>AA' \\<^sub>t Am)" using entt_disjD2 entt_disjE by blast ML \ signature SEPREF_FRAME = sig (* Check if subgoal is a frame obligation *) (*val is_frame : term -> bool *) (* Check if subgoal is a merge obligation *) val is_merge: term -> bool (* Perform frame inference *) val frame_tac: (Proof.context -> tactic') -> Proof.context -> tactic' (* Perform merging *) val merge_tac: (Proof.context -> tactic') -> Proof.context -> tactic' val frame_step_tac: (Proof.context -> tactic') -> bool -> Proof.context -> tactic' (* Reorder frame *) val prepare_frame_tac : Proof.context -> tactic' (* Solve a RECOVER_PURE goal, inserting constraints as necessary *) val recover_pure_tac: Proof.context -> tactic' (* Split precondition of hnr-goal into frame and arguments *) val align_goal_tac: Proof.context -> tactic' (* Normalize goal's precondition *) val norm_goal_pre_tac: Proof.context -> tactic' (* Rearrange precondition of hnr-term according to parameter order, normalize all relations *) val align_rl_conv: Proof.context -> conv (* Convert hn_invalid to \_ _. true in postcondition of hnr-goal. Makes proving the goal easier.*) val weaken_post_tac: Proof.context -> tactic' val add_normrel_eq : thm -> Context.generic -> Context.generic val del_normrel_eq : thm -> Context.generic -> Context.generic val get_normrel_eqs : Proof.context -> thm list val cfg_debug: bool Config.T val setup: theory -> theory end structure Sepref_Frame : SEPREF_FRAME = struct val cfg_debug = Attrib.setup_config_bool @{binding sepref_debug_frame} (K false) val DCONVERSION = Sepref_Debugging.DBG_CONVERSION cfg_debug val dbg_msg_tac = Sepref_Debugging.dbg_msg_tac cfg_debug structure normrel_eqs = Named_Thms ( val name = @{binding sepref_frame_normrel_eqs} val description = "Equations to normalize relations for frame matching" ) val add_normrel_eq = normrel_eqs.add_thm val del_normrel_eq = normrel_eqs.del_thm val get_normrel_eqs = normrel_eqs.get val mk_entailst = HOLogic.mk_binrel @{const_name "entailst"} local open Sepref_Basic Refine_Util Conv fun assn_ord p = case apply2 dest_hn_ctxt_opt p of (NONE,NONE) => EQUAL | (SOME _, NONE) => LESS | (NONE, SOME _) => GREATER | (SOME (_,a,_), SOME (_,a',_)) => Term_Ord.fast_term_ord (a,a') in fun reorder_ctxt_conv ctxt ct = let val cert = Thm.cterm_of ctxt val new_ct = Thm.term_of ct |> strip_star |> sort assn_ord |> list_star |> cert val thm = Goal.prove_internal ctxt [] (mk_cequals (ct,new_ct)) (fn _ => simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) 1) in thm end fun prepare_fi_conv ctxt ct = case Thm.term_of ct of - @{mpat "?P \\<^sub>t ?Q"} => let - val cert = Thm.cterm_of ctxt + (t as @{mpat "?P \\<^sub>t ?Q"}) => let (* Build table from abs-vars to ctxt *) val (Qm, Qum) = strip_star Q |> filter_out is_true |> List.partition is_hn_ctxt val Qtab = ( Qm |> map (fn x => (#2 (dest_hn_ctxt x),(NONE,x))) |> Termtab.make ) handle e as (Termtab.DUP _) => ( tracing ("Dup heap: " ^ @{make_string} ct); raise e) (* Go over entries in P and try to find a partner *) val (Qtab,Pum) = fold (fn a => fn (Qtab,Pum) => case dest_hn_ctxt_opt a of NONE => (Qtab,a::Pum) | SOME (_,p,_) => ( case Termtab.lookup Qtab p of SOME (NONE,tg) => (Termtab.update (p,(SOME a,tg)) Qtab, Pum) | _ => (Qtab,a::Pum) ) ) (strip_star P) (Qtab,[]) val Pum = filter_out is_true Pum (* Read out information from Qtab *) val (pairs,Qum2) = Termtab.dest Qtab |> map #2 |> List.partition (is_some o #1) |> apfst (map (apfst the)) |> apsnd (map #2) (* Build reordered terms: P' = fst pairs * Pum, Q' = snd pairs * (Qum2*Qum) *) val P' = mk_star (list_star (map fst pairs), list_star Pum) val Q' = mk_star (list_star (map snd pairs), list_star (Qum2@Qum)) - val new_ct = mk_entailst (P', Q') |> cert - - val msg_tac = dbg_msg_tac (Sepref_Debugging.msg_allgoals "Solving frame permutation") ctxt 1 - val tac = msg_tac THEN ALLGOALS (resolve_tac ctxt @{thms reorder_enttI}) THEN star_permute_tac ctxt + val new_t = mk_entailst (P', Q') + val goal_t = Logic.mk_equals (t,new_t) - val thm = Goal.prove_internal ctxt [] (mk_cequals (ct,new_ct)) (fn _ => tac) + val goal_ctxt = Variable.declare_term goal_t ctxt + + val msg_tac = dbg_msg_tac (Sepref_Debugging.msg_allgoals "Solving frame permutation") goal_ctxt 1 + val tac = + msg_tac + THEN ALLGOALS (resolve_tac goal_ctxt @{thms reorder_enttI}) + THEN star_permute_tac goal_ctxt + + val goal_ct = Thm.cterm_of ctxt goal_t + + val thm = Goal.prove_internal ctxt [] goal_ct (fn _ => tac) in thm end | _ => no_conv ct end fun is_merge @{mpat "Trueprop (_ \\<^sub>A _ \\<^sub>t _)"} = true | is_merge _ = false fun is_gen_frame @{mpat "Trueprop (_ \\<^sub>t _)"} = true | is_gen_frame _ = false fun prepare_frame_tac ctxt = let open Refine_Util Conv val frame_ss = put_simpset HOL_basic_ss ctxt addsimps @{thms mult_1_right[where 'a=assn] mult_1_left[where 'a=assn]} in CONVERSION Thm.eta_conversion THEN' (*CONCL_COND' is_frame THEN'*) simp_tac frame_ss THEN' - CONVERSION (HOL_concl_conv (fn _ => prepare_fi_conv ctxt) ctxt) + CONVERSION (HOL_concl_conv prepare_fi_conv ctxt) end local fun wrap_side_tac side_tac dbg tac = tac THEN_ALL_NEW_FWD ( CONCL_COND' is_gen_frame ORELSE' (if dbg then TRY_SOLVED' else SOLVED') side_tac ) in fun frame_step_tac side_tac dbg ctxt = let open Refine_Util Conv (* Constraint solving is built-in *) val side_tac = Sepref_Constraints.constraint_tac ctxt ORELSE' side_tac ctxt val frame_thms = @{thms frame_thms} @ Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_frame_match_rules} val merge_thms = @{thms merge_thms} @ Named_Theorems.get ctxt @{named_theorems sepref_frame_merge_rules} val ss = put_simpset HOL_basic_ss ctxt addsimps normrel_eqs.get ctxt fun frame_thm_tac dbg = wrap_side_tac side_tac dbg (resolve_tac ctxt frame_thms) fun merge_thm_tac dbg = wrap_side_tac side_tac dbg (resolve_tac ctxt merge_thms) fun thm_tac dbg = CONCL_COND' is_merge THEN_ELSE' (merge_thm_tac dbg, frame_thm_tac dbg) in full_simp_tac ss THEN' thm_tac dbg end end fun frame_loop_tac side_tac ctxt = let in TRY o ( REPEAT_ALL_NEW (DETERM o frame_step_tac side_tac false ctxt) ) end fun frame_tac side_tac ctxt = let open Refine_Util Conv val frame_rem_thms = @{thms frame_rem_thms} @ Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_frame_rem_rules} val solve_remainder_tac = TRY o REPEAT_ALL_NEW (DETERM o resolve_tac ctxt frame_rem_thms) in (prepare_frame_tac ctxt THEN' resolve_tac ctxt @{thms ent_star_mono entt_star_mono}) THEN_ALL_NEW_LIST [ frame_loop_tac side_tac ctxt, solve_remainder_tac ] end fun merge_tac side_tac ctxt = let open Refine_Util Conv - val merge_conv = arg1_conv (binop_conv (reorder_ctxt_conv ctxt)) + fun merge_conv ctxt = arg1_conv (binop_conv (reorder_ctxt_conv ctxt)) in CONVERSION Thm.eta_conversion THEN' CONCL_COND' is_merge THEN' simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) THEN' - CONVERSION (HOL_concl_conv (fn _ => merge_conv) ctxt) THEN' + CONVERSION (HOL_concl_conv merge_conv ctxt) THEN' frame_loop_tac side_tac ctxt end val setup = normrel_eqs.setup local open Sepref_Basic fun is_invalid @{mpat "hn_invalid _ _ _ :: assn"} = true | is_invalid _ = false fun contains_invalid @{mpat "Trueprop (RECOVER_PURE ?Q _)"} = exists is_invalid (strip_star Q) | contains_invalid _ = false in fun recover_pure_tac ctxt = CONCL_COND' contains_invalid THEN_ELSE' ( REPEAT_ALL_NEW (DETERM o (resolve_tac ctxt @{thms recover_pure} ORELSE' Sepref_Constraints.constraint_tac ctxt)), resolve_tac ctxt @{thms recover_pure_triv} ) end local open Sepref_Basic Refine_Util datatype cte = Other of term | Hn of term * term * term fun dest_ctxt_elem @{mpat "hn_ctxt ?R ?a ?c"} = Hn (R,a,c) | dest_ctxt_elem t = Other t fun mk_ctxt_elem (Other t) = t | mk_ctxt_elem (Hn (R,a,c)) = @{mk_term "hn_ctxt ?R ?a ?c"} fun match x (Hn (_,y,_)) = x aconv y | match _ _ = false fun dest_with_frame (*ctxt*) _ t = let val (P,c,Q,R,a) = dest_hn_refine t val (_,(_,args)) = dest_hnr_absfun a val pre_ctes = strip_star P |> map dest_ctxt_elem val (pre_args,frame) = (case split_matching match args pre_ctes of NONE => raise TERM("align_conv: Could not match all arguments",[P,a]) | SOME x => x) in ((frame,pre_args),c,Q,R,a) end fun align_goal_conv_aux ctxt t = let val ((frame,pre_args),c,Q,R,a) = dest_with_frame ctxt t val P' = apply2 (list_star o map mk_ctxt_elem) (frame,pre_args) |> mk_star val t' = mk_hn_refine (P',c,Q,R,a) in t' end fun align_rl_conv_aux ctxt t = let val ((frame,pre_args),c,Q,R,a) = dest_with_frame ctxt t val _ = frame = [] orelse raise TERM ("align_rl_conv: Extra preconditions in rule",[t,list_star (map mk_ctxt_elem frame)]) val P' = list_star (map mk_ctxt_elem pre_args) val t' = mk_hn_refine (P',c,Q,R,a) in t' end fun normrel_conv ctxt = let val ss = put_simpset HOL_basic_ss ctxt addsimps normrel_eqs.get ctxt in Simplifier.rewrite ss end in fun align_goal_conv ctxt = f_tac_conv ctxt (align_goal_conv_aux ctxt) star_permute_tac fun norm_goal_pre_conv ctxt = let open Conv - val nr_conv = normrel_conv ctxt + + fun conv ctxt = let + val nr_conv = normrel_conv ctxt + in + hn_refine_conv nr_conv all_conv all_conv all_conv all_conv + end in - HOL_concl_conv (fn _ => hn_refine_conv nr_conv all_conv all_conv all_conv all_conv) ctxt - end + HOL_concl_conv conv ctxt + end fun norm_goal_pre_tac ctxt = CONVERSION (norm_goal_pre_conv ctxt) fun align_rl_conv ctxt = let open Conv - val nr_conv = normrel_conv ctxt + fun conv ctxt = let + val nr_conv = normrel_conv ctxt + in + hn_refine_conv nr_conv all_conv nr_conv nr_conv all_conv + end in HOL_concl_conv (fn ctxt => f_tac_conv ctxt (align_rl_conv_aux ctxt) star_permute_tac) ctxt - then_conv HOL_concl_conv (K (hn_refine_conv nr_conv all_conv nr_conv nr_conv all_conv)) ctxt + then_conv HOL_concl_conv conv ctxt end fun align_goal_tac ctxt = CONCL_COND' is_hn_refine_concl THEN' DCONVERSION ctxt (HOL_concl_conv align_goal_conv ctxt) end fun weaken_post_tac ctxt = TRADE (fn ctxt => resolve_tac ctxt @{thms weaken_hnr_postI} THEN' SOLVED' (REPEAT_ALL_NEW (DETERM o resolve_tac ctxt @{thms weaken_hnr_post weaken_hnr_post_triv})) ) ctxt end \ setup Sepref_Frame.setup method_setup weaken_hnr_post = \Scan.succeed (fn ctxt => SIMPLE_METHOD' (Sepref_Frame.weaken_post_tac ctxt))\ \Convert "hn_invalid" to "hn_ctxt (\_ _. true)" in postcondition of hn_refine goal\ (* TODO: Improper, modifies all h\_ premises that happen to be there. Use tagging to protect! *) method extract_hnr_invalids = ( rule hn_refine_preI, ((drule mod_starD hn_invalidI | elim conjE exE)+)? ) \ \Extract \hn_invalid _ _ _ = true\ preconditions from \hn_refine\ goal.\ lemmas [sepref_frame_normrel_eqs] = the_pure_pure pure_the_pure end diff --git a/thys/Refine_Imperative_HOL/Sepref_Monadify.thy b/thys/Refine_Imperative_HOL/Sepref_Monadify.thy --- a/thys/Refine_Imperative_HOL/Sepref_Monadify.thy +++ b/thys/Refine_Imperative_HOL/Sepref_Monadify.thy @@ -1,308 +1,308 @@ section \Monadify\ theory Sepref_Monadify imports Sepref_Basic Sepref_Id_Op begin text \ In this phase, a monadic program is converted to complete monadic form, that is, computation of compound expressions are made visible as top-level operations in the monad. The monadify process is separated into 2 steps. \begin{enumerate} \item In a first step, eta-expansion is used to add missing operands to operations and combinators. This way, operators and combinators always occur with the same arity, which simplifies further processing. \item In a second step, computation of compound operands is flattened, introducing new bindings for the intermediate values. \end{enumerate} \ definition SP \ \Tag to protect content from further application of arity and combinator equations\ where [simp]: "SP x \ x" lemma SP_cong[cong]: "SP x \ SP x" by simp lemma PR_CONST_cong[cong]: "PR_CONST x \ PR_CONST x" by simp definition RCALL \ \Tag that marks recursive call\ where [simp]: "RCALL D \ D" definition EVAL \ \Tag that marks evaluation of plain expression for monadify phase\ where [simp]: "EVAL x \ RETURN x" text \ Internally, the package first applies rewriting rules from \sepref_monadify_arity\, which use eta-expansion to ensure that every combinator has enough actual parameters. Moreover, this phase will mark recursive calls by the tag @{const RCALL}. Next, rewriting rules from \sepref_monadify_comb\ are used to add @{const EVAL}-tags to plain expressions that should be evaluated in the monad. The @{const EVAL} tags are flattened using a default simproc that generates left-to-right argument order. \ lemma monadify_simps: "Refine_Basic.bind$(RETURN$x)$(\\<^sub>2x. f x) = f x" "EVAL$x \ RETURN$x" by simp_all definition [simp]: "PASS \ RETURN" \ \Pass on value, invalidating old one\ lemma remove_pass_simps: "Refine_Basic.bind$(PASS$x)$(\\<^sub>2x. f x) \ f x" "Refine_Basic.bind$m$(\\<^sub>2x. PASS$x) \ m" by simp_all definition COPY :: "'a \ 'a" \ \Marks required copying of parameter\ where [simp]: "COPY x \ x" lemma RET_COPY_PASS_eq: "RETURN$(COPY$p) = PASS$p" by simp named_theorems_rev sepref_monadify_arity "Sepref.Monadify: Arity alignment equations" named_theorems_rev sepref_monadify_comb "Sepref.Monadify: Combinator equations" ML \ structure Sepref_Monadify = struct local fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T)) fun lambda2_name n t = let val t = @{mk_term "PROTECT2 ?t DUMMY"} in Term.lambda_name n t end fun bind_args exp0 [] = exp0 | bind_args exp0 ((x,m)::xms) = let val lr = bind_args exp0 xms |> incr_boundvars 1 |> lambda2_name x in @{mk_term "Refine_Basic.bind$?m$?lr"} end fun monadify t = let val (f,args) = Autoref_Tagging.strip_app t val _ = not (is_Abs f) orelse raise TERM ("monadify: higher-order",[t]) val argTs = map fastype_of args (*val args = map monadify args*) val args = map (fn a => @{mk_term "EVAL$?a"}) args (*val fT = fastype_of f val argTs = binder_types fT*) val argVs = tag_list 0 argTs |> map cr_var val res0 = let val x = Autoref_Tagging.list_APP (f,map #2 argVs) in @{mk_term "SP (RETURN$?x)"} end val res = bind_args res0 (argVs ~~ args) in res end fun monadify_conv_aux ctxt ct = case Thm.term_of ct of @{mpat "EVAL$_"} => let fun tac goal_ctxt = simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms monadify_simps SP_def}) 1 in (*Refine_Util.monitor_conv "monadify"*) ( Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac) ct end | t => raise TERM ("monadify_conv",[t]) (*fun extract_comb_conv ctxt = Conv.rewrs_conv (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb}) *) in (* val monadify_conv = Conv.top_conv (fn ctxt => Conv.try_conv ( extract_comb_conv ctxt else_conv monadify_conv_aux ctxt ) ) *) val monadify_simproc = Simplifier.make_simproc @{context} "monadify_simproc" {lhss = [Logic.varify_global @{term "EVAL$a"}], proc = K (try o monadify_conv_aux)}; end local open Sepref_Basic fun mark_params t = let val (P,c,Q,R,a) = dest_hn_refine t val pps = strip_star P |> map_filter (dest_hn_ctxt_opt #> map_option #2) fun tr env (t as @{mpat "RETURN$?x"}) = if is_Bound x orelse member (aconv) pps x then @{mk_term env: "PASS$?x"} else t | tr env (t1$t2) = tr env t1 $ tr env t2 | tr env (Abs (x,T,t)) = Abs (x,T,tr (T::env) t) | tr _ t = t val a = tr [] a in mk_hn_refine (P,c,Q,R,a) end in fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt (mark_params) (fn goal_ctxt => simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms PASS_def}) 1) end local open Sepref_Basic fun dp ctxt (@{mpat "Refine_Basic.bind$(PASS$?p)$(?t' AS\<^sub>p (\_. PROTECT2 _ DUMMY))"}) = let val (t',ps) = let val ((t',rc),ctxt) = dest_lambda_rc ctxt t' val f = case t' of @{mpat "PROTECT2 ?f _"} => f | _ => raise Match val (f,ps) = dp ctxt f val t' = @{mk_term "PROTECT2 ?f DUMMY"} val t' = rc t' in (t',ps) end val dup = member (aconv) ps p val t = if dup then @{mk_term "Refine_Basic.bind$(RETURN$(COPY$?p))$?t'"} else @{mk_term "Refine_Basic.bind$(PASS$?p)$?t'"} in (t,p::ps) end | dp ctxt (t1$t2) = (#1 (dp ctxt t1) $ #1 (dp ctxt t2),[]) | dp ctxt (t as (Abs _)) = (apply_under_lambda (#1 oo dp) ctxt t,[]) | dp _ t = (t,[]) fun dp_conv ctxt = Refine_Util.f_tac_conv ctxt (#1 o dp ctxt) (fn goal_ctxt => ALLGOALS (simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms RET_COPY_PASS_eq}))) in fun dup_tac ctxt = CONVERSION (Sepref_Basic.hn_refine_concl_conv_a dp_conv ctxt) end fun arity_tac ctxt = let val arity1_ss = put_simpset HOL_basic_ss ctxt addsimps ((Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_arity})) |> Simplifier.add_cong @{thm SP_cong} |> Simplifier.add_cong @{thm PR_CONST_cong} val arity2_ss = put_simpset HOL_basic_ss ctxt addsimps @{thms beta SP_def} in simp_tac arity1_ss THEN' simp_tac arity2_ss end fun comb_tac ctxt = let val comb1_ss = put_simpset HOL_basic_ss ctxt addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_comb}) (*addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})*) addsimprocs [monadify_simproc] |> Simplifier.add_cong @{thm SP_cong} |> Simplifier.add_cong @{thm PR_CONST_cong} val comb2_ss = put_simpset HOL_basic_ss ctxt addsimps @{thms SP_def} in simp_tac comb1_ss THEN' simp_tac comb2_ss end (*fun ops_tac ctxt = CONVERSION ( Sepref_Basic.hn_refine_concl_conv_a monadify_conv ctxt)*) fun mark_params_tac ctxt = CONVERSION ( - Refine_Util.HOL_concl_conv (K (mark_params_conv ctxt)) ctxt) + Refine_Util.HOL_concl_conv mark_params_conv ctxt) fun contains_eval @{mpat "Trueprop (hn_refine _ _ _ _ ?a)"} = Term.exists_subterm (fn @{mpat EVAL} => true | _ => false) a | contains_eval t = raise TERM("contains_eval",[t]); fun remove_pass_tac ctxt = simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps}) fun monadify_tac dbg ctxt = let open Sepref_Basic in PHASES' [ ("arity", arity_tac, 0), ("comb", comb_tac, 0), (*("ops", ops_tac, 0),*) ("check_EVAL", K (CONCL_COND' (not o contains_eval)), 0), ("mark_params", mark_params_tac, 0), ("dup", dup_tac, 0), ("remove_pass", remove_pass_tac, 0) ] (flag_phases_ctrl dbg) ctxt end end \ lemma dflt_arity[sepref_monadify_arity]: "RETURN \ \\<^sub>2x. SP RETURN$x" "RECT \ \\<^sub>2B x. SP RECT$(\\<^sub>2D x. B$(\\<^sub>2x. RCALL$D$x)$x)$x" "case_list \ \\<^sub>2fn fc l. SP case_list$fn$(\\<^sub>2x xs. fc$x$xs)$l" "case_prod \ \\<^sub>2fp p. SP case_prod$(\\<^sub>2a b. fp$a$b)$p" "case_option \ \\<^sub>2fn fs ov. SP case_option$fn$(\\<^sub>2x. fs$x)$ov" "If \ \\<^sub>2b t e. SP If$b$t$e" "Let \ \\<^sub>2x f. SP Let$x$(\\<^sub>2x. f$x)" by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def) lemma dflt_comb[sepref_monadify_comb]: "\B x. RECT$B$x \ Refine_Basic.bind$(EVAL$x)$(\\<^sub>2x. SP (RECT$B$x))" "\D x. RCALL$D$x \ Refine_Basic.bind$(EVAL$x)$(\\<^sub>2x. SP (RCALL$D$x))" "\fn fc l. case_list$fn$fc$l \ Refine_Basic.bind$(EVAL$l)$(\\<^sub>2l. (SP case_list$fn$fc$l))" "\fp p. case_prod$fp$p \ Refine_Basic.bind$(EVAL$p)$(\\<^sub>2p. (SP case_prod$fp$p))" "\fn fs ov. case_option$fn$fs$ov \ Refine_Basic.bind$(EVAL$ov)$(\\<^sub>2ov. (SP case_option$fn$fs$ov))" "\b t e. If$b$t$e \ Refine_Basic.bind$(EVAL$b)$(\\<^sub>2b. (SP If$b$t$e))" "\x. RETURN$x \ Refine_Basic.bind$(EVAL$x)$(\\<^sub>2x. SP (RETURN$x))" "\x f. Let$x$f \ Refine_Basic.bind$(EVAL$x)$(\\<^sub>2x. (SP Let$x$f))" by (simp_all) lemma dflt_plain_comb[sepref_monadify_comb]: "EVAL$(If$b$t$e) \ Refine_Basic.bind$(EVAL$b)$(\\<^sub>2b. If$b$(EVAL$t)$(EVAL$e))" "EVAL$(case_list$fn$(\\<^sub>2x xs. fc x xs)$l) \ Refine_Basic.bind$(EVAL$l)$(\\<^sub>2l. case_list$(EVAL$fn)$(\\<^sub>2x xs. EVAL$(fc x xs))$l)" "EVAL$(case_prod$(\\<^sub>2a b. fp a b)$p) \ Refine_Basic.bind$(EVAL$p)$(\\<^sub>2p. case_prod$(\\<^sub>2a b. EVAL$(fp a b))$p)" "EVAL$(case_option$fn$(\\<^sub>2x. fs x)$ov) \ Refine_Basic.bind$(EVAL$ov)$(\\<^sub>2ov. case_option$(EVAL$fn)$(\\<^sub>2x. EVAL$(fs x))$ov)" "EVAL $ (Let $ v $ (\\<^sub>2x. f x)) \ (\) $ (EVAL $ v) $ (\\<^sub>2x. EVAL $ (f x))" apply (rule eq_reflection, simp split: list.split prod.split option.split)+ done lemma evalcomb_PR_CONST[sepref_monadify_comb]: "EVAL$(PR_CONST x) \ SP (RETURN$(PR_CONST x))" by simp end