diff --git a/thys/Collections/ICF/tools/Locale_Code.thy b/thys/Collections/ICF/tools/Locale_Code.thy --- a/thys/Collections/ICF/tools/Locale_Code.thy +++ b/thys/Collections/ICF/tools/Locale_Code.thy @@ -1,386 +1,386 @@ section \Code Generation from Locales\ theory Locale_Code imports ICF_Tools Ord_Code_Preproc begin text \ Provides a simple mechanism to prepare code equations for constants stemming from locale interpretations. The usage pattern is as follows: \setup Locale_Code.checkpoint\ is called before a series of interpretations, and afterwards, \setup Locale_Code.prepare\ is called. Afterwards, the code generator will correctly recognize expressions involving terms from the locale interpretation. \ text \Tag to indicate pattern deletion\ definition LC_DEL :: "'a \ unit" where "LC_DEL a \ ()" ML \ signature LOCALE_CODE = sig type pat_eq = cterm * thm list val open_block: theory -> theory val close_block: theory -> theory val del_pat: cterm -> theory -> theory val add_pat_eq: cterm -> thm list -> theory -> theory val lc_decl_eq: thm list -> local_theory -> local_theory val lc_decl_del: term -> local_theory -> local_theory val setup: theory -> theory val get_unf_ss: theory -> simpset val tracing_enabled: bool Unsynchronized.ref end structure Locale_Code :LOCALE_CODE = struct open ICF_Tools val tracing_enabled = Unsynchronized.ref false; type pat_eq = cterm * thm list type block_data = {idx:int, del_pats: cterm list, add_pateqs: pat_eq list} val closed_block = {idx = ~1, del_pats=[], add_pateqs=[]}; fun init_block idx = {idx = idx, del_pats=[], add_pateqs=[]}; fun is_open ({idx,...}:block_data) = idx <> ~1; fun assert_open bd = if is_open bd then () else error "Locale_Code: No open block"; fun assert_closed bd = if is_open bd then error "Locale_Code: Block already open" else (); fun merge_bd (bd1,bd2) = ( if is_open bd1 orelse is_open bd2 then error "Locale_Code: Merge with open block" else (); closed_block ); fun bd_add_del_pats ps {idx,del_pats,add_pateqs} = {idx = idx, del_pats = ps@del_pats, add_pateqs = add_pateqs}; fun bd_add_add_pateqs pes {idx,del_pats,add_pateqs} = {idx = idx, del_pats = del_pats, add_pateqs = pes@add_pateqs}; structure BlockData = Theory_Data ( type T = block_data val empty = (closed_block) val extend = I val merge = merge_bd ); structure FoldSSData = Oc_Simpset ( val prio = 5; val name = "Locale_Code"; ); fun add_unf_thms thms thy = let val ctxt = Proof_Context.init_global thy val thms = map Thm.symmetric thms in FoldSSData.map (fn ss => put_simpset ss ctxt |> sss_add thms |> simpset_of ) thy end val get_unf_ss = FoldSSData.get; (* First order match with fixed head *) fun match_fixed_head (pat,obj) = let (* Match heads *) val inst = Thm.first_order_match (chead_of pat, chead_of obj); val pat = Thm.instantiate_cterm inst pat; (* Now match whole pattern *) val inst = Thm.first_order_match (pat, obj); in inst end; val matches_fixed_head = can match_fixed_head; (* First order match of heads only *) fun match_heads (pat,obj) = Thm.first_order_match (chead_of pat, chead_of obj); val matches_heads = can match_heads; val pat_nargs = Thm.term_of #> strip_comb #> #2 #> length; (* Adjust a theorem to exactly match pattern *) fun norm_thm_pat (thm,pat) = let val thm = norm_def_thm thm; val na_pat = pat_nargs pat; val lhs = Thm.lhs_of thm; val na_lhs = pat_nargs lhs; val lhs' = if na_lhs > na_pat then funpow (na_lhs - na_pat) Thm.dest_fun lhs else lhs; val inst = Thm.first_order_match (lhs',pat); in Thm.instantiate inst thm end; fun del_pat_matches cpat (epat,_) = if pat_nargs cpat = 0 then matches_heads (cpat,epat) else matches_fixed_head (cpat,epat); (* Pattern-Eqs from specification *) local datatype action = ADD of (cterm * thm list) | DEL of cterm fun filter_pat_eq thy thms pat = let val cpat = Thm.global_cterm_of thy pat; in if (pat_nargs cpat = 0) then NONE else let val thms' = fold (fn thm => fn acc => case try norm_thm_pat (thm, cpat) of NONE => acc | SOME thm => thm::acc ) thms []; in case thms' of [] => NONE | _ => SOME (ADD (cpat,thms')) end end; fun process_actions acc [] = acc | process_actions acc (ADD peq::acts) = process_actions (peq::acc) acts | process_actions acc (DEL cpat::acts) = let val acc' = filter (not o curry renames_cterm cpat o fst) acc; val _ = if length acc = length acc' then warning ("Locale_Code: LC_DEL without effect: " ^ @{make_string} cpat) else (); in process_actions acc' acts end; fun pat_eqs_of_spec thy {rough_classification = Spec_Rules.Equational _, terms = pats, rules = thms, ...} = map_filter (filter_pat_eq thy thms) pats | pat_eqs_of_spec thy {rough_classification = Spec_Rules.Unknown, terms = [Const (@{const_name LC_DEL},_)$pat], ...} = [(DEL (Thm.global_cterm_of thy pat))] | pat_eqs_of_spec _ _ = []; in fun pat_eqs_of_specs thy specs = map (pat_eqs_of_spec thy) specs |> flat |> rev |> process_actions []; end; fun is_proper_pat cpat = let val pat = Thm.term_of cpat; val (f,args) = strip_comb pat; in is_Const f andalso args <> [] andalso not (is_Var (hd (rev args))) end; (* Instantiating pattern-eq *) local (* Get constant name for instantiation pattern *) fun inst_name lthy pat = let val (fname,params) = case strip_comb pat of ((Const (fname,_)),params) => (fname,params) | _ => raise TERM ("inst_name: Expected pattern",[pat]); fun pname (Const (n,_)) = Long_Name.base_name n | pname (s$t) = pname s ^ "_" ^ pname t | pname _ = Name.uu; in space_implode "_" (Long_Name.base_name fname::map pname params) |> gen_variant (can (Proof_Context.read_const {proper = true, strict = false} lthy)) end; in fun inst_pat_eq (cpat,thms) = wrap_lthy_result_global (fn lthy => let val ((inst,thms),lthy) = Variable.import true thms lthy; val cpat = Thm.instantiate_cterm inst cpat; val pat = Thm.term_of cpat; val name = inst_name lthy pat; val ((_,(_,def_thm)),lthy) = Local_Theory.define ((Binding.name name,NoSyn), ((Binding.name (Thm.def_name name),[]),pat)) lthy; val thms' = map (Local_Defs.fold lthy [def_thm]) thms; in ((def_thm,thms'),lthy) end) (fn m => fn (def_thm,thms') => (Morphism.thm m def_thm, map (Morphism.thm m) thms')) #> (fn ((def_thm,thms'),thy) => let val thy = thy |> add_unf_thms [def_thm] |> Code.declare_default_eqns_global (map (rpair true) thms'); in thy end) end (* Bookkeeping *) fun new_specs thy = let val bd = BlockData.get thy; val _ = assert_open bd; val ctxt = Proof_Context.init_global thy; val srules = Spec_Rules.get ctxt; val res = take (length srules - #idx bd) srules; in res end fun open_block thy = let val bd = BlockData.get thy; val _ = assert_closed bd; val ctxt = Proof_Context.init_global thy; val idx = length (Spec_Rules.get ctxt); val thy = BlockData.map (K (init_block idx)) thy; in thy end; fun process_block bd thy = let fun filter_del_pats cpat peqs = let val peqs' = filter (not o del_pat_matches cpat) peqs val _ = if length peqs = length peqs' then warning ("Locale_Code: No pattern-eqs matching filter: " ^ @{make_string} cpat) else (); in peqs' end; fun filter_add_pats (orig_pat,_) = forall (fn (add_pat,_) => not (renames_cterm (orig_pat,add_pat))) (#add_pateqs bd); val specs = new_specs thy; val peqs = pat_eqs_of_specs thy specs |> fold filter_del_pats (#del_pats bd) |> filter filter_add_pats; val peqs = peqs @ #add_pateqs bd; val peqs = rev peqs; (* Important: Process equations in the order in that they have been added! *) val _ = if !tracing_enabled then map (fn peq => (tracing (@{make_string} peq); ())) peqs else []; val thy = thy |> fold inst_pat_eq peqs; in thy end; fun close_block thy = let val bd = BlockData.get thy; val _ = assert_open bd; val thy = process_block bd thy |> BlockData.map (K closed_block); in thy end; fun del_pat cpat thy = let val bd = BlockData.get thy; val _ = assert_open bd; val bd = bd_add_del_pats [cpat] bd; val thy = BlockData.map (K bd) thy; in thy end; fun add_pat_eq cpat thms thy = let val _ = is_proper_pat cpat orelse raise CTERM ("add_pat_eq: Not a proper pattern",[cpat]); fun ntp thm = case try norm_thm_pat (thm,cpat) of NONE => raise THM ("add_pat_eq: Theorem does not match pattern",~1,[thm]) | SOME thm => thm; val thms = map ntp thms; val thy = BlockData.map (bd_add_add_pateqs [(cpat,thms)]) thy; in thy end; local fun cpat_of_thm thm = let fun strip ct = case Thm.term_of ct of (_$Var _) => strip (Thm.dest_fun ct) | _ => ct; in strip (Thm.lhs_of thm) end; fun adjust_length (cpat1,cpat2) = let val n1 = cpat1 |> Thm.term_of |> strip_comb |> #2 |> length; val n2 = cpat2 |> Thm.term_of |> strip_comb |> #2 |> length; in if n1>n2 then (funpow (n1-n2) Thm.dest_fun cpat1, cpat2) else (cpat1, funpow (n2-n1) Thm.dest_fun cpat2) end fun find_match cpat cpat' = SOME (cpat,rename_cterm (cpat',cpat)) handle Pattern.MATCH => (case Thm.term_of cpat' of _$_ => find_match (Thm.dest_fun cpat) (Thm.dest_fun cpat') | _ => NONE ); (* Common head of definitional theorems *) fun comp_head thms = case map norm_def_thm thms of [] => NONE | thm::thms => let fun ch [] r = SOME r | ch (thm::thms) (cpat,acc) = let val cpat' = cpat_of_thm thm; val (cpat,cpat') = adjust_length (cpat,cpat') in case find_match cpat cpat' of NONE => NONE | SOME (cpat,inst) => ch thms (cpat, Drule.instantiate_normalize inst thm :: acc) end; in ch thms (cpat_of_thm thm,[thm]) end; in fun lc_decl_eq thms lthy = case comp_head thms of SOME (cpat,thms) => let val _ = if !tracing_enabled then tracing ("decl_eq: " ^ @{make_string} cpat ^ ": " ^ @{make_string} thms) else (); fun decl m = let val cpat'::thms' = Morphism.fact m (Drule.mk_term cpat :: thms); val cpat' = Drule.dest_term cpat'; in Context.mapping (BlockData.map (bd_add_add_pateqs [(cpat',thms')])) I end in lthy |> Local_Theory.declaration {syntax = false, pervasive = false} decl end | NONE => raise THM ("Locale_Code.lc_decl_eq: No common pattern",~1,thms); end; fun lc_decl_del pat = let val ty = fastype_of pat; val dpat = Const (@{const_name LC_DEL},ty --> @{typ unit})$pat; in - Spec_Rules.add "" Spec_Rules.Unknown [dpat] [] + Spec_Rules.add Binding.empty Spec_Rules.Unknown [dpat] [] end (* Package setup *) val setup = FoldSSData.setup; end \ setup Locale_Code.setup attribute_setup lc_delete = \ Parse.and_list1' ICF_Tools.parse_cpat >> (fn cpats => Thm.declaration_attribute (K (Context.mapping (fold Locale_Code.del_pat cpats) I))) \ "Locale_Code: Delete patterns for current block" attribute_setup lc_add = \ Parse.and_list1' (ICF_Tools.parse_cpat -- Attrib.thms) >> (fn peqs => Thm.declaration_attribute (K (Context.mapping (fold (uncurry Locale_Code.add_pat_eq) peqs) I))) \ "Locale_Code: Add pattern-eqs for current block" end diff --git a/thys/Nominal2/nominal_termination.ML b/thys/Nominal2/nominal_termination.ML --- a/thys/Nominal2/nominal_termination.ML +++ b/thys/Nominal2/nominal_termination.ML @@ -1,115 +1,115 @@ (* Nominal Termination Author: Christian Urban heavily based on the code of Alexander Krauss (code forked on 18 July 2011) Redefinition of the termination command *) signature NOMINAL_FUNCTION_TERMINATION = sig include NOMINAL_FUNCTION_DATA val termination : bool -> term option -> local_theory -> Proof.state val termination_cmd : bool -> string option -> local_theory -> Proof.state end structure Nominal_Function_Termination : NOMINAL_FUNCTION_TERMINATION = struct open Function_Lib open Function_Common open Nominal_Function_Common val simp_attribs = map (Attrib.internal o K) [Simplifier.simp_add, Named_Theorems.add @{named_theorems nitpick_simp}] val eqvt_attrib = Attrib.internal (K Nominal_ThmDecls.eqvt_add) fun prepare_termination_proof prep_term is_eqvt raw_term_opt lthy = let val term_opt = Option.map (prep_term lthy) raw_term_opt val info = the (case term_opt of SOME t => (import_function_data t lthy handle Option.Option => error ("Not a function: " ^ quote (Syntax.string_of_term lthy t))) | NONE => (import_last_function lthy handle Option.Option => error "Not a function")) val { termination, fs, R, add_simps, case_names, psimps, pinducts, defname, eqvts, ...} = info val domT = domain_type (fastype_of R) val goal = HOLogic.mk_Trueprop (HOLogic.mk_all ("x", domT, mk_acc domT R $ Free ("x", domT))) fun afterqed [[totality]] lthy = let val totality = Thm.close_derivation \<^here> totality val remove_domain_condition = full_simplify (put_simpset HOL_basic_ss lthy addsimps [totality, @{thm True_implies_equals}]) val tsimps = map remove_domain_condition psimps val tinducts = map remove_domain_condition pinducts val teqvts = map remove_domain_condition eqvts fun qualify n = Binding.name n |> Binding.qualify true defname in lthy |> add_simps I "simps" I simp_attribs tsimps ||>> Local_Theory.note ((qualify "eqvt", if is_eqvt then [eqvt_attrib] else []), teqvts) ||>> Local_Theory.note ((qualify "induct", [Attrib.internal (K (Rule_Cases.case_names case_names))]), tinducts) |-> (fn ((simps, (_, eqvts)), (_, inducts)) => fn lthy => let val info' = { is_partial=false, defname=defname, add_simps=add_simps, case_names=case_names, fs=fs, R=R, psimps=psimps, pinducts=pinducts, simps=SOME simps, inducts=SOME inducts, termination=termination, eqvts=teqvts } in lthy |> Local_Theory.declaration {syntax = false, pervasive = false} (add_function_data o morph_function_data info') - |> Spec_Rules.add "" Spec_Rules.equational_recdef fs tsimps + |> Spec_Rules.add Binding.empty Spec_Rules.equational_recdef fs tsimps |> Code.declare_default_eqns (map (rpair true) tsimps) |> pair info' end) end in (goal, afterqed, termination) end fun gen_termination prep_term is_eqvt raw_term_opt lthy = let val (goal, afterqed, termination) = prepare_termination_proof prep_term is_eqvt raw_term_opt lthy in lthy |> Proof_Context.note_thmss "" [((Binding.empty, [Context_Rules.rule_del]), [([allI], [])])] |> snd |> Proof_Context.note_thmss "" [((Binding.empty, [Context_Rules.intro_bang (SOME 1)]), [([allI], [])])] |> snd |> Proof_Context.note_thmss "" [((Binding.name "termination", [Context_Rules.intro_bang (SOME 0)]), [([Goal.norm_result lthy termination], [])])] |> snd |> Proof.theorem NONE (snd oo afterqed) [[(goal, [])]] end val termination = gen_termination Syntax.check_term val termination_cmd = gen_termination Syntax.read_term (* outer syntax *) val option_parser = (Scan.optional (@{keyword "("} |-- Parse.!!! ((Parse.reserved "eqvt" >> K true) || (Parse.reserved "no_eqvt" >> K false)) --| @{keyword ")"}) (false)) val _ = Outer_Syntax.local_theory_to_proof @{command_keyword nominal_termination} "prove termination of a recursive nominal function" (option_parser -- Scan.option Parse.term >> (fn (is_eqvt, opt_trm) => termination_cmd is_eqvt opt_trm)) end diff --git a/thys/Partial_Function_MR/partial_function_mr.ML b/thys/Partial_Function_MR/partial_function_mr.ML --- a/thys/Partial_Function_MR/partial_function_mr.ML +++ b/thys/Partial_Function_MR/partial_function_mr.ML @@ -1,339 +1,339 @@ (* Author: Rene Thiemann, License: LGPL *) signature PARTIAL_FUNCTION_MR = sig val init: string -> (* make monad_map: monad term * funs * monad as typ * monad bs typ * a->b typs -> map_monad funs monad term *) (term * term list * typ * typ * typ list -> term) -> (* make monad type: fixed and flexible types *) (typ list * typ list -> typ) -> (* destruct monad type: fixed and flexible types *) (typ -> typ list * typ list) -> (* monad_map_compose thm: mapM f (mapM g x) = mapM (f o g) x *) thm list -> (* monad_map_ident thm: mapM (% y. y) x = x *) thm list -> declaration val add_partial_function_mr: string -> (binding * typ option * mixfix) list -> Specification.multi_specs -> local_theory -> thm list * local_theory val add_partial_function_mr_cmd: string -> (binding * string option * mixfix) list -> Specification.multi_specs_cmd -> local_theory -> thm list * local_theory end; structure Partial_Function_MR: PARTIAL_FUNCTION_MR = struct val partial_function_mr_trace = Attrib.setup_config_bool @{binding partial_function_mr_trace} (K false); fun trace ctxt msg = if Config.get ctxt partial_function_mr_trace then tracing msg else () datatype setup_data = Setup_Data of {mk_monad_map: term * term list * typ * typ * typ list -> term, mk_monadT: typ list * typ list -> typ, dest_monadT: typ -> typ list * typ list, monad_map_comp: thm list, monad_map_id: thm list}; (* the following code has been copied from partial_function.ML *) structure Modes = Generic_Data ( type T = setup_data Symtab.table; val empty = Symtab.empty; val extend = I; fun merge data = Symtab.merge (K true) data; ) val known_modes = Symtab.keys o Modes.get o Context.Proof; val lookup_mode = Symtab.lookup o Modes.get o Context.Proof; fun curry_const (A, B, C) = Const (@{const_name Product_Type.curry}, [HOLogic.mk_prodT (A, B) --> C, A, B] ---> C); fun mk_curry f = case fastype_of f of Type ("fun", [Type (_, [S, T]), U]) => curry_const (S, T, U) $ f | T => raise TYPE ("mk_curry", [T], [f]); fun curry_n arity = funpow (arity - 1) mk_curry; fun uncurry_n arity = funpow (arity - 1) HOLogic.mk_case_prod; (* end copy of partial_function.ML *) fun init mode mk_monad_map mk_monadT dest_monadT monad_map_comp monad_map_id phi = let val thm = Morphism.thm phi; (* TODO: are there morphisms required on mk_monad_map???, ... *) val data' = Setup_Data {mk_monad_map=mk_monad_map, mk_monadT=mk_monadT, dest_monadT=dest_monadT, monad_map_comp=map thm monad_map_comp,monad_map_id=map thm monad_map_id}; in Modes.map (Symtab.update (mode, data')) end fun mk_sumT (T1,T2) = Type (@{type_name sum}, [T1,T2]) fun mk_choiceT [ty] = ty | mk_choiceT (ty :: more) = mk_sumT (ty,mk_choiceT more) | mk_choiceT _ = error "mk_choiceT []" fun mk_choice_resT mk_monadT dest_monadT mTs = let val (commonTs,argTs) = map dest_monadT mTs |> split_list |> apfst hd; val n = length (hd argTs); val new = map (fn i => mk_choiceT (map (fn xs => nth xs i) argTs)) (0 upto (n - 1)) in mk_monadT (commonTs,new) end; fun mk_inj [_] t _ = t | mk_inj (ty :: more) t n = let val moreT = mk_choiceT more; val allT = mk_sumT (ty,moreT) in if n = 0 then Const (@{const_name Inl}, ty --> allT) $ t else Const (@{const_name Inr}, moreT --> allT) $ mk_inj more t (n-1) end | mk_inj _ _ _ = error "mk_inj [] _ _" fun mk_proj [_] t _ = t | mk_proj (ty :: more) t n = let val moreT = mk_choiceT more; val allT = mk_sumT (ty,moreT) in if n = 0 then Const (@{const_name Sum_Type.projl}, allT --> ty) $ t else mk_proj more (Const (@{const_name Sum_Type.projr}, allT --> moreT) $ t) (n-1) end | mk_proj _ _ _ = error "mk_proj [] _ _" fun get_head ctxt (_,(_,eqn)) = let val ((_, plain_eqn), _) = Variable.focus NONE eqn ctxt; val lhs = HOLogic.dest_eq (HOLogic.dest_Trueprop plain_eqn) |> #1; val head = strip_comb lhs |> #1; in head end; fun get_infos lthy heads (fix,(_,eqn)) = let val ((_, plain_eqn), _) = Variable.focus NONE eqn lthy; val ((f_binding, fT), mixfix) = fix; val fname = Binding.name_of f_binding; val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop plain_eqn); val (_, args) = strip_comb lhs; val F = fold_rev lambda (heads @ args) rhs; val arity = length args; val (aTs, bTs) = chop arity (binder_types fT); val tupleT = foldl1 HOLogic.mk_prodT aTs; val fT_uc = tupleT :: bTs ---> body_type fT; val (inT,resT) = dest_funT fT_uc; val f_uc = Free (fname, fT_uc); val f_cuc = curry_n arity f_uc in (fname, f_cuc, f_uc, inT, resT, ((f_binding,mixfix),fT), F, arity, args) end; fun fresh_var ctxt name = Name.variant name (Variable.names_of ctxt) |> #1 (* partial_function_mr definition *) fun gen_add_partial_function_mr prep mode fixes_raw eqns_raw lthy = let val setup_data = the (lookup_mode lthy mode) handle Option.Option => error (cat_lines ["Unknown mode " ^ quote mode ^ ".", "Known modes are " ^ commas_quote (known_modes lthy) ^ "."]); val Setup_Data {mk_monad_map, mk_monadT, dest_monadT, monad_map_comp, monad_map_id} = setup_data; val _ = if length eqns_raw < 2 then error "require at least two function definitions" else (); val ((fixes, eq_abinding_eqns), _) = prep fixes_raw eqns_raw lthy; val _ = if length eqns_raw = length fixes then () else error "# of eqns does not match # of constants"; val fix_eq_abinding_eqns = fixes ~~ eq_abinding_eqns; val heads = map (get_head lthy) fix_eq_abinding_eqns; val fnames = map (Binding.name_of o #1 o #1) fixes val fnames' = map (#1 o Term.dest_Free) heads val f_f = fnames ~~ fnames' val _ = case find_first (fn (f,g) => not (f = g)) f_f of NONE => () | SOME _ => error ("list of function symbols does not match list of equations:\n" ^ @{make_string} fnames ^ "\nvs\n" ^ @{make_string} fnames') val all = map (get_infos lthy heads) fix_eq_abinding_eqns val f_cucs = map #2 all val f_ucs = map #3 all val inTs = map #4 all val resTs = map #5 all val bindings_types = map #6 all val Fs = map #7 all val arities = map #8 all val all_args = map #9 all val glob_inT = mk_choiceT inTs val glob_resT = mk_choice_resT mk_monadT dest_monadT resTs val inj = mk_inj inTs val glob_fname = fresh_var lthy (foldl1 (fn (a,b) => a ^ "_" ^ b) (fnames @ [serial_string ()])) val glob_constT = glob_inT --> glob_resT; val glob_const = Free (glob_fname, glob_constT) val nums = 0 upto (length all - 1) fun mk_res_inj_proj n = let val resT = nth resTs n val glob_Targs = dest_monadT glob_resT |> #2 val res_Targs = dest_monadT resT |> #2 val m = length res_Targs fun inj_proj m = let val resTs_m = map (fn resT => nth (dest_monadT resT |> #2) m) resTs val resT_arg = nth resTs_m n val globT_arg = nth glob_Targs m val x = Free ("x",resT_arg) val y = Free ("x",globT_arg) val inj = lambda x (mk_inj resTs_m x n) val proj = lambda y (mk_proj resTs_m y n) in ((inj, resT_arg --> globT_arg), (proj, globT_arg --> resT_arg)) end; val (inj,proj) = map inj_proj (0 upto (m - 1)) |> split_list val (t_to_ss_inj,t_to_sTs_inj) = split_list inj; val (t_to_ss_proj,t_to_sTs_proj) = split_list proj; in (fn mt => mk_monad_map (mt, t_to_ss_inj, resT, glob_resT, t_to_sTs_inj), fn mt => mk_monad_map (mt, t_to_ss_proj, glob_resT, resT, t_to_sTs_proj)) end; val (res_inj, res_proj) = map mk_res_inj_proj nums |> split_list fun mk_global_fun n = let val fname = nth fnames n val inT = nth inTs n val xs = Free (fresh_var lthy ("x_" ^ fname), inT) val inj_xs = inj xs n val glob_inj_xs = glob_const $ inj_xs val glob_inj_xs_map = nth res_proj n glob_inj_xs val res = lambda xs glob_inj_xs_map in (xs,res) end val (xss,global_funs) = map mk_global_fun nums |> split_list fun mk_cases n = let val xs = nth xss n val F = nth Fs n; val arity = nth arities n; val F_uc = fold_rev lambda f_ucs (uncurry_n arity (list_comb (F, f_cucs))); val F_uc_inst = Term.betapplys (F_uc,global_funs) val res = lambda xs (nth res_inj n (F_uc_inst $ xs)) in res end; val all_cases = map mk_cases nums; fun combine_cases [cs] [_] = cs | combine_cases (cs :: more) (inT :: moreTy) = let val moreT = mk_choiceT moreTy val sumT = mk_sumT (inT, moreT) val case_const = Const (@{const_name case_sum}, (inT --> glob_resT) --> (moreT --> glob_resT) --> sumT --> glob_resT) in case_const $ cs $ combine_cases more moreTy end | combine_cases _ _ = error "combine_cases with incompatible argument lists"; val glob_x_name = fresh_var lthy ("x_" ^ glob_fname) val glob_x = Free (glob_x_name,glob_inT) val rhs = combine_cases all_cases inTs $ glob_x; val lhs = glob_const $ glob_x val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs,rhs)) val glob_binding = Binding.name (glob_fname) |> Binding.concealed val glob_attrib_binding = Binding.empty_atts val _ = trace lthy "invoking partial_function on global function" val priv_lthy = lthy |> Proof_Context.private_scope (Binding.new_scope()) val ((glob_const, glob_simp_thm),priv_lthy') = priv_lthy |> Partial_Function.add_partial_function mode [(glob_binding,SOME glob_constT,NoSyn)] (glob_attrib_binding,eq) val glob_lthy = priv_lthy' |> Proof_Context.restore_naming lthy val _ = trace lthy "deriving simp rules for separate functions from global function" fun define_f n (fs, fdefs,rhss,lthy) = let val ((fbinding,mixfix),_) = nth bindings_types n val fname = nth fnames n val inT = nth inTs n; val arity = nth arities n; val x = Free (fresh_var lthy ("x_" ^ fname), inT) val inj_argsProd = inj x n val call = glob_const $ inj_argsProd val post = nth res_proj n call val rhs = curry_n arity (lambda x post) val ((f, (_, f_def)),lthy') = Local_Theory.define_internal ((fbinding,mixfix), (Binding.empty_atts, rhs)) lthy in (f :: fs, f_def :: fdefs,rhs :: rhss,lthy') end val (fs,fdefs,f_rhss,local_lthy) = fold_rev define_f nums ([],[],[],glob_lthy) val glob_simp_thm' = let fun mk_case_new n = let val F = nth Fs n val arity = nth arities n val Finst = uncurry_n arity (Term.betapplys (F,fs)) val xs = nth xss n val res = lambda xs (nth res_inj n (Finst $ xs)) in res end; val new_cases = map mk_case_new nums; val rhs = combine_cases new_cases inTs $ glob_x; val lhs = glob_const $ glob_x val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs,rhs)) in Goal.prove local_lthy [glob_x_name] [] eq (fn {prems = _, context = ctxt} => Thm.instantiate' [] [SOME (Thm.cterm_of ctxt glob_x)] glob_simp_thm |> (fn simp_thm => unfold_tac ctxt [simp_thm] THEN unfold_tac ctxt fdefs)) end fun mk_simp_thm n = let val args = nth all_args n val arg_names = map (dest_Free #> fst) args val f = nth fs n val F = nth Fs n val fdef = nth fdefs n val lhs = list_comb (f,args); val mhs = Term.betapplys (nth f_rhss n, args) val rhs = list_comb (list_comb (F,fs), args); val eq1 = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs,mhs)) val eq2 = HOLogic.mk_Trueprop (HOLogic.mk_eq (mhs,rhs)) val simp_thm1 = Goal.prove local_lthy arg_names [] eq1 (fn {prems = _, context = ctxt} => unfold_tac ctxt [fdef]) val simp_thm2 = Goal.prove local_lthy arg_names [] eq2 (fn {prems = _, context = ctxt} => unfold_tac ctxt [glob_simp_thm'] THEN unfold_tac ctxt @{thms sum.simps curry_def split} THEN unfold_tac ctxt (@{thm o_def} :: monad_map_comp) THEN unfold_tac ctxt (monad_map_id @ @{thms sum.sel})) in @{thm trans} OF [simp_thm1,simp_thm2] end val simp_thms = map mk_simp_thm nums fun register n lthy = let val simp_thm = nth simp_thms n val eq_abinding = nth eq_abinding_eqns n |> fst val fname = nth fnames n val f = nth fs n in lthy |> Local_Theory.note (eq_abinding, [simp_thm]) |-> (fn (_, simps) => - Spec_Rules.add "" Spec_Rules.equational_recdef [f] simps + Spec_Rules.add Binding.empty Spec_Rules.equational_recdef [f] simps #> Local_Theory.note ((Binding.qualify true fname (Binding.name "simps"), @{attributes [code]}), simps) #>> snd #>> hd) end in fold (fn i => fn (simps, lthy) => case register i lthy of (simp, lthy') => (simps @ [simp], lthy')) nums ([], local_lthy) end; val add_partial_function_mr = gen_add_partial_function_mr Specification.check_multi_specs; val add_partial_function_mr_cmd = gen_add_partial_function_mr Specification.read_multi_specs; val mode = @{keyword "("} |-- Parse.name --| @{keyword ")"}; val _ = Outer_Syntax.local_theory @{command_keyword partial_function_mr} "define mutually recursive partial functions" (mode -- Parse_Spec.specification >> (fn (mode, (fixes, specs)) => add_partial_function_mr_cmd mode fixes specs #> #2)); end