diff --git a/src/HOL/Tools/BNF/bnf_fp_util.ML b/src/HOL/Tools/BNF/bnf_fp_util.ML --- a/src/HOL/Tools/BNF/bnf_fp_util.ML +++ b/src/HOL/Tools/BNF/bnf_fp_util.ML @@ -1,1027 +1,984 @@ (* Title: HOL/Tools/BNF/bnf_fp_util.ML Author: Dmitriy Traytel, TU Muenchen Author: Jasmin Blanchette, TU Muenchen Author: Martin Desharnais, TU Muenchen Author: Stefan Berghofer, TU Muenchen Copyright 2012, 2013, 2014 Shared library for the datatype and codatatype constructions. *) signature BNF_FP_UTIL = sig exception EMPTY_DATATYPE of string type fp_result = {Ts: typ list, bnfs: BNF_Def.bnf list, pre_bnfs: BNF_Def.bnf list, absT_infos: BNF_Comp.absT_info list, ctors: term list, dtors: term list, xtor_un_folds: term list, xtor_co_recs: term list, xtor_co_induct: thm, dtor_ctors: thm list, ctor_dtors: thm list, ctor_injects: thm list, dtor_injects: thm list, xtor_maps: thm list, xtor_map_unique: thm, xtor_setss: thm list list, xtor_rels: thm list, xtor_un_fold_thms: thm list, xtor_co_rec_thms: thm list, xtor_un_fold_unique: thm, xtor_co_rec_unique: thm, xtor_un_fold_o_maps: thm list, xtor_co_rec_o_maps: thm list, xtor_un_fold_transfers: thm list, xtor_co_rec_transfers: thm list, xtor_rel_co_induct: thm, dtor_set_inducts: thm list} val morph_fp_result: morphism -> fp_result -> fp_result val time: Proof.context -> Timer.real_timer -> string -> Timer.real_timer val fixpoint: ('a * 'a -> bool) -> ('a list -> 'a list) -> 'a list -> 'a list val IITN: string val LevN: string val algN: string val behN: string val bisN: string val carTN: string val caseN: string val coN: string val coinductN: string val coinduct_strongN: string val corecN: string val corec_discN: string val corec_disc_iffN: string val ctorN: string val ctor_dtorN: string val ctor_exhaustN: string val ctor_induct2N: string val ctor_inductN: string val ctor_injectN: string val ctor_foldN: string val ctor_fold_o_mapN: string val ctor_fold_transferN: string val ctor_fold_uniqueN: string val ctor_mapN: string val ctor_map_uniqueN: string val ctor_recN: string val ctor_rec_o_mapN: string val ctor_rec_transferN: string val ctor_rec_uniqueN: string val ctor_relN: string val ctor_rel_inductN: string val ctor_set_inclN: string val ctor_set_set_inclN: string val dtorN: string val dtor_coinductN: string val dtor_corecN: string val dtor_corec_o_mapN: string val dtor_corec_transferN: string val dtor_corec_uniqueN: string val dtor_ctorN: string val dtor_exhaustN: string val dtor_injectN: string val dtor_mapN: string val dtor_map_coinductN: string val dtor_map_coinduct_strongN: string val dtor_map_uniqueN: string val dtor_relN: string val dtor_rel_coinductN: string val dtor_set_inclN: string val dtor_set_set_inclN: string val dtor_coinduct_strongN: string val dtor_unfoldN: string val dtor_unfold_o_mapN: string val dtor_unfold_transferN: string val dtor_unfold_uniqueN: string val exhaustN: string val colN: string val inductN: string val injectN: string val isNodeN: string val lsbisN: string val mapN: string val map_uniqueN: string val min_algN: string val morN: string val nchotomyN: string val recN: string val rel_casesN: string val rel_coinductN: string val rel_inductN: string val rel_injectN: string val rel_introsN: string val rel_distinctN: string val rel_selN: string val rvN: string val corec_selN: string val set_inclN: string val set_set_inclN: string val setN: string val simpsN: string val strTN: string val str_initN: string val sum_bdN: string val sum_bdTN: string val uniqueN: string (* TODO: Don't index set facts. Isabelle packages traditionally generate uniform names. *) val mk_ctor_setN: int -> string val mk_dtor_setN: int -> string val mk_dtor_set_inductN: int -> string val mk_set_inductN: int -> string val co_prefix: BNF_Util.fp_kind -> string val split_conj_thm: thm -> thm list val split_conj_prems: int -> thm -> thm val mk_sumTN: typ list -> typ val mk_sumTN_balanced: typ list -> typ val mk_tupleT_balanced: typ list -> typ val mk_sumprodT_balanced: typ list list -> typ val mk_proj: typ -> int -> int -> term val mk_convol: term * term -> term val mk_rel_prod: term -> term -> term val mk_rel_sum: term -> term -> term val Inl_const: typ -> typ -> term val Inr_const: typ -> typ -> term val mk_tuple_balanced: term list -> term val mk_tuple1_balanced: typ list -> term list -> term val abs_curried_balanced: typ list -> term -> term val mk_tupled_fun: term -> term -> term list -> term val mk_case_sum: term * term -> term val mk_case_sumN: term list -> term val mk_case_absumprod: typ -> term -> term list -> term list list -> term list list -> term val mk_Inl: typ -> term -> term val mk_Inr: typ -> term -> term val mk_sumprod_balanced: typ -> int -> int -> term list -> term val mk_absumprod: typ -> term -> int -> int -> term list -> term val dest_sumT: typ -> typ * typ val dest_sumTN_balanced: int -> typ -> typ list val dest_tupleT_balanced: int -> typ -> typ list val dest_absumprodT: typ -> typ -> int -> int list -> typ -> typ list list val If_const: typ -> term val mk_Field: term -> term val mk_If: term -> term -> term -> term val mk_absumprodE: thm -> int list -> thm val mk_sum_caseN: int -> int -> thm val mk_sum_caseN_balanced: int -> int -> thm val mk_sum_Cinfinite: thm list -> thm val mk_sum_card_order: thm list -> thm val force_typ: Proof.context -> typ -> term -> term val mk_xtor_rel_co_induct_thm: BNF_Util.fp_kind -> term list -> term list -> term list -> term list -> term list -> term list -> term list -> term list -> ({prems: thm list, context: Proof.context} -> tactic) -> Proof.context -> thm val mk_xtor_co_iter_transfer_thms: BNF_Util.fp_kind -> term list -> term list -> term list -> term list -> term list -> term list -> term list -> ({prems: thm list, context: Proof.context} -> tactic) -> Proof.context -> thm list val mk_xtor_co_iter_o_map_thms: BNF_Util.fp_kind -> bool -> int -> thm -> thm list -> thm list -> thm list -> thm list -> thm list val derive_xtor_co_recs: BNF_Util.fp_kind -> binding list -> (typ list -> typ list) -> (typ list list * typ list) -> BNF_Def.bnf list -> term list -> term list -> thm -> thm list -> thm list -> thm list -> thm list -> (BNF_Comp.absT_info * BNF_Comp.absT_info) option list -> local_theory -> (term list * (thm list * thm * thm list * thm list)) * local_theory val raw_qualify: (binding -> binding) -> binding -> binding -> binding val fixpoint_bnf: bool -> (binding -> binding) -> (binding list -> (string * sort) list -> typ list * typ list list -> BNF_Def.bnf list -> BNF_Comp.absT_info list -> local_theory -> 'a) -> binding list -> (string * sort) list -> (string * sort) list -> (string * sort) list -> typ list -> BNF_Comp.comp_cache -> local_theory -> ((BNF_Def.bnf list * BNF_Comp.absT_info list) * BNF_Comp.comp_cache) * 'a end; structure BNF_FP_Util : BNF_FP_UTIL = struct open Ctr_Sugar open BNF_Comp open BNF_Def open BNF_Util open BNF_FP_Util_Tactics exception EMPTY_DATATYPE of string; type fp_result = {Ts: typ list, bnfs: bnf list, pre_bnfs: BNF_Def.bnf list, absT_infos: BNF_Comp.absT_info list, ctors: term list, dtors: term list, xtor_un_folds: term list, xtor_co_recs: term list, xtor_co_induct: thm, dtor_ctors: thm list, ctor_dtors: thm list, ctor_injects: thm list, dtor_injects: thm list, xtor_maps: thm list, xtor_map_unique: thm, xtor_setss: thm list list, xtor_rels: thm list, xtor_un_fold_thms: thm list, xtor_co_rec_thms: thm list, xtor_un_fold_unique: thm, xtor_co_rec_unique: thm, xtor_un_fold_o_maps: thm list, xtor_co_rec_o_maps: thm list, xtor_un_fold_transfers: thm list, xtor_co_rec_transfers: thm list, xtor_rel_co_induct: thm, dtor_set_inducts: thm list}; fun morph_fp_result phi {Ts, bnfs, pre_bnfs, absT_infos, ctors, dtors, xtor_un_folds, xtor_co_recs, xtor_co_induct, dtor_ctors, ctor_dtors, ctor_injects, dtor_injects, xtor_maps, xtor_map_unique, xtor_setss, xtor_rels, xtor_un_fold_thms, xtor_co_rec_thms, xtor_un_fold_unique, xtor_co_rec_unique, xtor_un_fold_o_maps, xtor_co_rec_o_maps, xtor_un_fold_transfers, xtor_co_rec_transfers, xtor_rel_co_induct, dtor_set_inducts} = {Ts = map (Morphism.typ phi) Ts, bnfs = map (morph_bnf phi) bnfs, pre_bnfs = map (morph_bnf phi) pre_bnfs, absT_infos = map (morph_absT_info phi) absT_infos, ctors = map (Morphism.term phi) ctors, dtors = map (Morphism.term phi) dtors, xtor_un_folds = map (Morphism.term phi) xtor_un_folds, xtor_co_recs = map (Morphism.term phi) xtor_co_recs, xtor_co_induct = Morphism.thm phi xtor_co_induct, dtor_ctors = map (Morphism.thm phi) dtor_ctors, ctor_dtors = map (Morphism.thm phi) ctor_dtors, ctor_injects = map (Morphism.thm phi) ctor_injects, dtor_injects = map (Morphism.thm phi) dtor_injects, xtor_maps = map (Morphism.thm phi) xtor_maps, xtor_map_unique = Morphism.thm phi xtor_map_unique, xtor_setss = map (map (Morphism.thm phi)) xtor_setss, xtor_rels = map (Morphism.thm phi) xtor_rels, xtor_un_fold_thms = map (Morphism.thm phi) xtor_un_fold_thms, xtor_co_rec_thms = map (Morphism.thm phi) xtor_co_rec_thms, xtor_un_fold_unique = Morphism.thm phi xtor_un_fold_unique, xtor_co_rec_unique = Morphism.thm phi xtor_co_rec_unique, xtor_un_fold_o_maps = map (Morphism.thm phi) xtor_un_fold_o_maps, xtor_co_rec_o_maps = map (Morphism.thm phi) xtor_co_rec_o_maps, xtor_un_fold_transfers = map (Morphism.thm phi) xtor_un_fold_transfers, xtor_co_rec_transfers = map (Morphism.thm phi) xtor_co_rec_transfers, xtor_rel_co_induct = Morphism.thm phi xtor_rel_co_induct, dtor_set_inducts = map (Morphism.thm phi) dtor_set_inducts}; fun time ctxt timer msg = (if Config.get ctxt bnf_timing then warning (msg ^ ": " ^ string_of_int (Time.toMilliseconds (Timer.checkRealTimer timer)) ^ " ms") else (); Timer.startRealTimer ()); val preN = "pre_" val rawN = "raw_" val coN = "co" val unN = "un" val algN = "alg" val IITN = "IITN" val foldN = "fold" val unfoldN = unN ^ foldN val uniqueN = "unique" val transferN = "transfer" val simpsN = "simps" val ctorN = "ctor" val dtorN = "dtor" val ctor_foldN = ctorN ^ "_" ^ foldN val dtor_unfoldN = dtorN ^ "_" ^ unfoldN val ctor_fold_uniqueN = ctor_foldN ^ "_" ^ uniqueN val ctor_fold_o_mapN = ctor_foldN ^ "_o_" ^ mapN val dtor_unfold_uniqueN = dtor_unfoldN ^ "_" ^ uniqueN val dtor_unfold_o_mapN = dtor_unfoldN ^ "_o_" ^ mapN val ctor_fold_transferN = ctor_foldN ^ "_" ^ transferN val dtor_unfold_transferN = dtor_unfoldN ^ "_" ^ transferN val ctor_mapN = ctorN ^ "_" ^ mapN val dtor_mapN = dtorN ^ "_" ^ mapN val map_uniqueN = mapN ^ "_" ^ uniqueN val ctor_map_uniqueN = ctorN ^ "_" ^ map_uniqueN val dtor_map_uniqueN = dtorN ^ "_" ^ map_uniqueN val min_algN = "min_alg" val morN = "mor" val bisN = "bis" val lsbisN = "lsbis" val sum_bdTN = "sbdT" val sum_bdN = "sbd" val carTN = "carT" val strTN = "strT" val isNodeN = "isNode" val LevN = "Lev" val rvN = "recover" val behN = "beh" val setN = "set" val mk_ctor_setN = prefix (ctorN ^ "_") o mk_setN val mk_dtor_setN = prefix (dtorN ^ "_") o mk_setN fun mk_set_inductN i = mk_setN i ^ "_induct" val mk_dtor_set_inductN = prefix (dtorN ^ "_") o mk_set_inductN val str_initN = "str_init" val recN = "rec" val corecN = coN ^ recN val ctor_recN = ctorN ^ "_" ^ recN val ctor_rec_o_mapN = ctor_recN ^ "_o_" ^ mapN val ctor_rec_transferN = ctor_recN ^ "_" ^ transferN val ctor_rec_uniqueN = ctor_recN ^ "_" ^ uniqueN val dtor_corecN = dtorN ^ "_" ^ corecN val dtor_corec_o_mapN = dtor_corecN ^ "_o_" ^ mapN val dtor_corec_transferN = dtor_corecN ^ "_" ^ transferN val dtor_corec_uniqueN = dtor_corecN ^ "_" ^ uniqueN val ctor_dtorN = ctorN ^ "_" ^ dtorN val dtor_ctorN = dtorN ^ "_" ^ ctorN val nchotomyN = "nchotomy" val injectN = "inject" val exhaustN = "exhaust" val ctor_injectN = ctorN ^ "_" ^ injectN val ctor_exhaustN = ctorN ^ "_" ^ exhaustN val dtor_injectN = dtorN ^ "_" ^ injectN val dtor_exhaustN = dtorN ^ "_" ^ exhaustN val ctor_relN = ctorN ^ "_" ^ relN val dtor_relN = dtorN ^ "_" ^ relN val inductN = "induct" val coinductN = coN ^ inductN val ctor_inductN = ctorN ^ "_" ^ inductN val ctor_induct2N = ctor_inductN ^ "2" val dtor_map_coinductN = dtor_mapN ^ "_" ^ coinductN val dtor_coinductN = dtorN ^ "_" ^ coinductN val coinduct_strongN = coinductN ^ "_strong" val dtor_map_coinduct_strongN = dtor_mapN ^ "_" ^ coinduct_strongN val dtor_coinduct_strongN = dtorN ^ "_" ^ coinduct_strongN val colN = "col" val set_inclN = "set_incl" val ctor_set_inclN = ctorN ^ "_" ^ set_inclN val dtor_set_inclN = dtorN ^ "_" ^ set_inclN val set_set_inclN = "set_set_incl" val ctor_set_set_inclN = ctorN ^ "_" ^ set_set_inclN val dtor_set_set_inclN = dtorN ^ "_" ^ set_set_inclN val caseN = "case" val discN = "disc" val corec_discN = corecN ^ "_" ^ discN val iffN = "_iff" val corec_disc_iffN = corec_discN ^ iffN val distinctN = "distinct" val rel_distinctN = relN ^ "_" ^ distinctN val injectN = "inject" val rel_casesN = relN ^ "_cases" val rel_injectN = relN ^ "_" ^ injectN val rel_introsN = relN ^ "_intros" val rel_coinductN = relN ^ "_" ^ coinductN val rel_selN = relN ^ "_sel" val dtor_rel_coinductN = dtorN ^ "_" ^ rel_coinductN val rel_inductN = relN ^ "_" ^ inductN val ctor_rel_inductN = ctorN ^ "_" ^ rel_inductN val selN = "sel" val corec_selN = corecN ^ "_" ^ selN fun co_prefix fp = case_fp fp "" "co"; fun dest_sumT (Type (\<^type_name>\sum\, [T, T'])) = (T, T'); val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT; fun dest_tupleT_balanced 0 \<^typ>\unit\ = [] | dest_tupleT_balanced n T = Balanced_Tree.dest HOLogic.dest_prodT n T; fun dest_absumprodT absT repT n ms = map2 dest_tupleT_balanced ms o dest_sumTN_balanced n o mk_repT absT repT; val mk_sumTN = Library.foldr1 mk_sumT; val mk_sumTN_balanced = Balanced_Tree.make mk_sumT; fun mk_tupleT_balanced [] = HOLogic.unitT | mk_tupleT_balanced Ts = Balanced_Tree.make HOLogic.mk_prodT Ts; val mk_sumprodT_balanced = mk_sumTN_balanced o map mk_tupleT_balanced; fun mk_proj T n k = let val (binders, _) = strip_typeN n T in fold_rev (fn T => fn t => Abs (Name.uu, T, t)) binders (Bound (n - k - 1)) end; fun mk_convol (f, g) = let val (fU, fTU) = `range_type (fastype_of f); val ((gT, gU), gTU) = `dest_funT (fastype_of g); val convolT = fTU --> gTU --> gT --> HOLogic.mk_prodT (fU, gU); in Const (\<^const_name>\convol\, convolT) $ f $ g end; fun mk_rel_prod R S = let val ((A1, A2), RT) = `dest_pred2T (fastype_of R); val ((B1, B2), ST) = `dest_pred2T (fastype_of S); val rel_prodT = RT --> ST --> mk_pred2T (HOLogic.mk_prodT (A1, B1)) (HOLogic.mk_prodT (A2, B2)); in Const (\<^const_name>\rel_prod\, rel_prodT) $ R $ S end; fun mk_rel_sum R S = let val ((A1, A2), RT) = `dest_pred2T (fastype_of R); val ((B1, B2), ST) = `dest_pred2T (fastype_of S); val rel_sumT = RT --> ST --> mk_pred2T (mk_sumT (A1, B1)) (mk_sumT (A2, B2)); in Const (\<^const_name>\rel_sum\, rel_sumT) $ R $ S end; fun Inl_const LT RT = Const (\<^const_name>\Inl\, LT --> mk_sumT (LT, RT)); fun mk_Inl RT t = Inl_const (fastype_of t) RT $ t; fun Inr_const LT RT = Const (\<^const_name>\Inr\, RT --> mk_sumT (LT, RT)); fun mk_Inr LT t = Inr_const LT (fastype_of t) $ t; fun mk_prod1 bound_Ts (t, u) = HOLogic.pair_const (fastype_of1 (bound_Ts, t)) (fastype_of1 (bound_Ts, u)) $ t $ u; fun mk_tuple1_balanced _ [] = HOLogic.unit | mk_tuple1_balanced bound_Ts ts = Balanced_Tree.make (mk_prod1 bound_Ts) ts; val mk_tuple_balanced = mk_tuple1_balanced []; fun abs_curried_balanced Ts t = t $ mk_tuple1_balanced (List.rev Ts) (map Bound (length Ts - 1 downto 0)) |> fold_rev (Term.abs o pair Name.uu) Ts; fun mk_sumprod_balanced T n k ts = Sum_Tree.mk_inj T n k (mk_tuple_balanced ts); fun mk_absumprod absT abs0 n k ts = let val abs = mk_abs absT abs0; in abs $ mk_sumprod_balanced (domain_type (fastype_of abs)) n k ts end; fun mk_case_sum (f, g) = let val (fT, T') = dest_funT (fastype_of f); val (gT, _) = dest_funT (fastype_of g); in Sum_Tree.mk_sumcase fT gT T' f g end; val mk_case_sumN = Library.foldr1 mk_case_sum; val mk_case_sumN_balanced = Balanced_Tree.make mk_case_sum; fun mk_tupled_fun f x xs = if xs = [x] then f else HOLogic.tupled_lambda x (Term.list_comb (f, xs)); fun mk_case_absumprod absT rep fs xss xss' = HOLogic.mk_comp (mk_case_sumN_balanced (@{map 3} mk_tupled_fun fs (map mk_tuple_balanced xss) xss'), mk_rep absT rep); fun If_const T = Const (\<^const_name>\If\, HOLogic.boolT --> T --> T --> T); fun mk_If p t f = let val T = fastype_of t in If_const T $ p $ t $ f end; fun mk_Field r = let val T = fst (dest_relT (fastype_of r)); in Const (\<^const_name>\Field\, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end; (*dangerous; use with monotonic, converging functions only!*) fun fixpoint eq f X = if subset eq (f X, X) then X else fixpoint eq f (f X); (* stolen from "~~/src/HOL/Tools/Datatype/datatype_aux.ML" *) fun split_conj_thm th = ((th RS conjunct1) :: split_conj_thm (th RS conjunct2)) handle THM _ => [th]; fun split_conj_prems limit th = let fun split n i th = if i = n then th else split n (i + 1) (conjI RSN (i, th)) handle THM _ => th; in split limit 1 th end; fun mk_obj_sumEN_balanced n = Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f}))) (replicate n asm_rl); fun mk_tupled_allIN_balanced 0 = @{thm unit_all_impI} | mk_tupled_allIN_balanced n = let val (tfrees, _) = BNF_Util.mk_TFrees n \<^context>; val T = mk_tupleT_balanced tfrees; in @{thm asm_rl[of "\x. P x \ Q x" for P Q]} |> Thm.instantiate' [SOME (Thm.ctyp_of \<^context> T)] [] |> Raw_Simplifier.rewrite_goals_rule \<^context> @{thms split_paired_All[THEN eq_reflection]} |> (fn thm => impI RS funpow n (fn th => allI RS th) thm) |> Thm.varifyT_global end; fun mk_absumprodE type_definition ms = let val n = length ms in mk_obj_sumEN_balanced n OF map mk_tupled_allIN_balanced ms RS (type_definition RS @{thm type_copy_obj_one_point_absE}) end; fun mk_sum_caseN 1 1 = refl | mk_sum_caseN _ 1 = @{thm sum.case(1)} | mk_sum_caseN 2 2 = @{thm sum.case(2)} | mk_sum_caseN n k = trans OF [@{thm case_sum_step(2)}, mk_sum_caseN (n - 1) (k - 1)]; fun mk_sum_step base step thm = if Thm.eq_thm_prop (thm, refl) then base else trans OF [step, thm]; fun mk_sum_caseN_balanced 1 1 = refl | mk_sum_caseN_balanced n k = Balanced_Tree.access {left = mk_sum_step @{thm sum.case(1)} @{thm case_sum_step(1)}, right = mk_sum_step @{thm sum.case(2)} @{thm case_sum_step(2)}, init = refl} n k; fun mk_sum_Cinfinite [thm] = thm | mk_sum_Cinfinite (thm :: thms) = @{thm Cinfinite_csum_weak} OF [thm, mk_sum_Cinfinite thms]; fun mk_sum_card_order [thm] = thm | mk_sum_card_order (thm :: thms) = @{thm card_order_csum} OF [thm, mk_sum_card_order thms]; fun mk_xtor_rel_co_induct_thm fp pre_rels pre_phis rels phis xs ys xtors xtor's tac lthy = let val pre_relphis = map (fn rel => Term.list_comb (rel, phis @ pre_phis)) pre_rels; val relphis = map (fn rel => Term.list_comb (rel, phis)) rels; fun mk_xtor fp' xtor x = if fp = fp' then xtor $ x else x; val dtor = mk_xtor Greatest_FP; val ctor = mk_xtor Least_FP; fun flip f x y = if fp = Greatest_FP then f y x else f x y; fun mk_prem pre_relphi phi x y xtor xtor' = HOLogic.mk_Trueprop (list_all_free [x, y] (flip (curry HOLogic.mk_imp) (pre_relphi $ (dtor xtor x) $ (dtor xtor' y)) (phi $ (ctor xtor x) $ (ctor xtor' y)))); val prems = @{map 6} mk_prem pre_relphis pre_phis xs ys xtors xtor's; val concl = HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (flip mk_leq) relphis pre_phis)); in Goal.prove_sorry lthy (map (fst o dest_Free) (phis @ pre_phis)) prems concl tac |> Thm.close_derivation \<^here> |> (fn thm => thm OF (replicate (length pre_rels) @{thm allI[OF allI[OF impI]]})) end; fun mk_xtor_co_iter_transfer_thms fp pre_rels pre_iphis pre_ophis rels phis un_folds un_folds' tac lthy = let val pre_relphis = map (fn rel => Term.list_comb (rel, phis @ pre_iphis)) pre_rels; val relphis = map (fn rel => Term.list_comb (rel, phis)) rels; fun flip f x y = if fp = Greatest_FP then f y x else f x y; val arg_rels = map2 (flip mk_rel_fun) pre_relphis pre_ophis; fun mk_transfer relphi pre_phi un_fold un_fold' = fold_rev mk_rel_fun arg_rels (flip mk_rel_fun relphi pre_phi) $ un_fold $ un_fold'; val transfers = @{map 4} mk_transfer relphis pre_ophis un_folds un_folds'; val goal = fold_rev Logic.all (phis @ pre_ophis) (HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj transfers)); in Goal.prove_sorry lthy [] [] goal tac |> Thm.close_derivation \<^here> |> split_conj_thm end; fun mk_xtor_co_iter_o_map_thms fp is_rec m un_fold_unique xtor_maps xtor_un_folds sym_map_comps map_cong0s = let val n = length sym_map_comps; val rewrite_comp_comp2 = case_fp fp @{thm rewriteR_comp_comp2} @{thm rewriteL_comp_comp2}; val rewrite_comp_comp = case_fp fp @{thm rewriteR_comp_comp} @{thm rewriteL_comp_comp}; val map_cong_passive_args1 = replicate m (case_fp fp @{thm id_comp} @{thm comp_id} RS fun_cong); val map_cong_active_args1 = replicate n (if is_rec then case_fp fp @{thm convol_o} @{thm o_case_sum} RS fun_cong else refl); val map_cong_passive_args2 = replicate m (case_fp fp @{thm comp_id} @{thm id_comp} RS fun_cong); val map_cong_active_args2 = replicate n (if is_rec then case_fp fp @{thm map_prod_o_convol_id} @{thm case_sum_o_map_sum_id} else case_fp fp @{thm id_comp} @{thm comp_id} RS fun_cong); fun mk_map_congs passive active = map (fn thm => thm OF (passive @ active) RS @{thm ext}) map_cong0s; val map_cong1s = mk_map_congs map_cong_passive_args1 map_cong_active_args1; val map_cong2s = mk_map_congs map_cong_passive_args2 map_cong_active_args2; fun mk_rewrites map_congs = map2 (fn sym_map_comp => fn map_cong => mk_trans sym_map_comp map_cong RS rewrite_comp_comp) sym_map_comps map_congs; val rewrite1s = mk_rewrites map_cong1s; val rewrite2s = mk_rewrites map_cong2s; val unique_prems = @{map 4} (fn xtor_map => fn un_fold => fn rewrite1 => fn rewrite2 => mk_trans (rewrite_comp_comp2 OF [xtor_map, un_fold]) (mk_trans rewrite1 (mk_sym rewrite2))) xtor_maps xtor_un_folds rewrite1s rewrite2s; in split_conj_thm (un_fold_unique OF map (case_fp fp I mk_sym) unique_prems) end; fun force_typ ctxt T = Term.map_types Type_Infer.paramify_vars #> Type.constraint T #> Syntax.check_term ctxt #> singleton (Variable.polymorphic ctxt); fun absT_info_encodeT thy (SOME (src : absT_info, dst : absT_info)) src_absT = let val src_repT = mk_repT (#absT src) (#repT src) src_absT; val dst_absT = mk_absT thy (#repT dst) (#absT dst) src_repT; in dst_absT end | absT_info_encodeT _ NONE T = T; fun absT_info_decodeT thy = absT_info_encodeT thy o Option.map swap; fun absT_info_encode thy fp (opt as SOME (src : absT_info, dst : absT_info)) t = let val co_alg_funT = case_fp fp domain_type range_type; fun co_swap pair = case_fp fp I swap pair; val mk_co_comp = curry (HOLogic.mk_comp o co_swap); val mk_co_abs = case_fp fp mk_abs mk_rep; val mk_co_rep = case_fp fp mk_rep mk_abs; val co_abs = case_fp fp #abs #rep; val co_rep = case_fp fp #rep #abs; val src_absT = co_alg_funT (fastype_of t); val dst_absT = absT_info_encodeT thy opt src_absT; val co_src_abs = mk_co_abs src_absT (co_abs src); val co_dst_rep = mk_co_rep dst_absT (co_rep dst); in mk_co_comp (mk_co_comp t co_src_abs) co_dst_rep end | absT_info_encode _ _ NONE t = t; fun absT_info_decode thy fp = absT_info_encode thy fp o Option.map swap; fun mk_xtor_un_fold_xtor_thms ctxt fp un_folds xtors xtor_un_fold_unique map_id0s absT_info_opts = let val thy = Proof_Context.theory_of ctxt; fun mk_goal un_fold = let val rhs = list_comb (un_fold, @{map 2} (absT_info_encode thy fp) absT_info_opts xtors); val T = range_type (fastype_of rhs); in HOLogic.mk_eq (HOLogic.id_const T, rhs) end; val goal = HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map mk_goal un_folds)); fun mk_inverses NONE = [] | mk_inverses (SOME (src, dst)) = [#type_definition dst RS @{thm type_definition.Abs_inverse[OF _ UNIV_I]}, #type_definition src RS @{thm type_definition.Rep_inverse}]; val inverses = maps mk_inverses absT_info_opts; in Goal.prove_sorry ctxt [] [] goal (fn {context = ctxt, prems = _} => mk_xtor_un_fold_xtor_tac ctxt xtor_un_fold_unique map_id0s inverses) |> split_conj_thm |> map mk_sym end; fun derive_xtor_co_recs fp bs mk_Ts (Dss, resDs) pre_bnfs xtors0 un_folds0 xtor_un_fold_unique xtor_un_folds xtor_un_fold_transfers xtor_maps xtor_rels absT_info_opts lthy = let val thy = Proof_Context.theory_of lthy; fun co_swap pair = case_fp fp I swap pair; val mk_co_comp = curry (HOLogic.mk_comp o co_swap); fun mk_co_algT T U = case_fp fp (T --> U) (U --> T); val co_alg_funT = case_fp fp domain_type range_type; val mk_co_product = curry (case_fp fp mk_convol mk_case_sum); val co_proj1_const = case_fp fp fst_const (uncurry Inl_const o dest_sumT) o co_alg_funT; val co_proj2_const = case_fp fp snd_const (uncurry Inr_const o dest_sumT) o co_alg_funT; val mk_co_productT = curry (case_fp fp HOLogic.mk_prodT mk_sumT); val rewrite_comp_comp = case_fp fp @{thm rewriteL_comp_comp} @{thm rewriteR_comp_comp}; val n = length pre_bnfs; val live = live_of_bnf (hd pre_bnfs); val m = live - n; val ks = 1 upto n; val map_id0s = map map_id0_of_bnf pre_bnfs; val map_comps = map map_comp_of_bnf pre_bnfs; val map_cong0s = map map_cong0_of_bnf pre_bnfs; val map_transfers = map map_transfer_of_bnf pre_bnfs; val sym_map_comp0s = map (mk_sym o map_comp0_of_bnf) pre_bnfs; val deads = fold (union (op =)) Dss resDs; val ((((As, Bs), Xs), Ys), names_lthy) = lthy |> fold Variable.declare_typ deads |> mk_TFrees m ||>> mk_TFrees m ||>> mk_TFrees n ||>> mk_TFrees n; val XFTs = @{map 2} (fn Ds => mk_T_of_bnf Ds (As @ Xs)) Dss pre_bnfs; val co_algXFTs = @{map 2} mk_co_algT XFTs Xs; val Ts = mk_Ts As; val un_foldTs = @{map 2} (fn T => fn X => co_algXFTs ---> mk_co_algT T X) Ts Xs; val un_folds = @{map 2} (force_typ names_lthy) un_foldTs un_folds0; val ABs = As ~~ Bs; val XYs = Xs ~~ Ys; val Us = map (typ_subst_atomic ABs) Ts; val TFTs = @{map 2} (fn Ds => mk_T_of_bnf Ds (As @ Ts)) Dss pre_bnfs; val TFTs' = @{map 2} (absT_info_decodeT thy) absT_info_opts TFTs; val xtors = @{map 3} (force_typ names_lthy oo mk_co_algT) TFTs' Ts xtors0; val ids = map HOLogic.id_const As; val co_rec_Xs = @{map 2} mk_co_productT Ts Xs; val co_rec_Ys = @{map 2} mk_co_productT Us Ys; val co_rec_algXs = @{map 2} mk_co_algT co_rec_Xs Xs; val co_proj1s = map co_proj1_const co_rec_algXs; val co_rec_maps = @{map 2} (fn Ds => mk_map_of_bnf Ds (As @ case_fp fp co_rec_Xs Ts) (As @ case_fp fp Ts co_rec_Xs)) Dss pre_bnfs; val co_rec_Ts = @{map 2} (fn Ds => mk_T_of_bnf Ds (As @ co_rec_Xs)) Dss pre_bnfs val co_rec_argTs = @{map 2} mk_co_algT co_rec_Ts Xs; val co_rec_resTs = @{map 2} mk_co_algT Ts Xs; val (((co_rec_ss, fs), xs), names_lthy) = names_lthy |> mk_Frees "s" co_rec_argTs ||>> mk_Frees "f" co_rec_resTs ||>> mk_Frees "x" (case_fp fp TFTs' Xs); val co_rec_strs = @{map 4} (fn xtor => fn s => fn mapx => fn absT_info_opt => mk_co_product (mk_co_comp (absT_info_encode thy fp absT_info_opt xtor) (list_comb (mapx, ids @ co_proj1s))) s) xtors co_rec_ss co_rec_maps absT_info_opts; val theta = Xs ~~ co_rec_Xs; val co_rec_un_folds = map (subst_atomic_types theta) un_folds; val co_rec_spec0s = map (fn un_fold => list_comb (un_fold, co_rec_strs)) co_rec_un_folds; val co_rec_ids = @{map 2} (mk_co_comp o co_proj1_const) co_rec_algXs co_rec_spec0s; val co_rec_specs = @{map 2} (mk_co_comp o co_proj2_const) co_rec_algXs co_rec_spec0s; val co_recN = case_fp fp ctor_recN dtor_corecN; fun co_rec_bind i = nth bs (i - 1) |> Binding.prefix_name (co_recN ^ "_"); val co_rec_def_bind = rpair [] o Binding.concealed o Thm.def_binding o co_rec_bind; fun co_rec_spec i = fold_rev (Term.absfree o Term.dest_Free) co_rec_ss (nth co_rec_specs (i - 1)); val ((co_rec_frees, (_, co_rec_def_frees)), (lthy, lthy_old)) = lthy |> Local_Theory.open_target |> snd |> fold_map (fn i => Local_Theory.define ((co_rec_bind i, NoSyn), (co_rec_def_bind i, co_rec_spec i))) ks |>> apsnd split_list o split_list ||> `Local_Theory.close_target; val phi = Proof_Context.export_morphism lthy_old lthy; val co_rec_names = map (fst o dest_Const o Morphism.term phi) co_rec_frees; val co_recs = @{map 2} (fn name => fn resT => Const (name, co_rec_argTs ---> resT)) co_rec_names co_rec_resTs; val co_rec_defs = map (fn def => mk_unabs_def n (HOLogic.mk_obj_eq (Morphism.thm phi def))) co_rec_def_frees; val xtor_un_fold_xtor_thms = mk_xtor_un_fold_xtor_thms lthy fp (map (Term.subst_atomic_types (Xs ~~ Ts)) un_folds) xtors xtor_un_fold_unique map_id0s absT_info_opts; val co_rec_id_thms = let val goal = @{map 2} (fn T => fn t => HOLogic.mk_eq (t, HOLogic.id_const T)) Ts co_rec_ids |> Library.foldr1 HOLogic.mk_conj |> HOLogic.mk_Trueprop; val vars = Variable.add_free_names lthy goal []; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, prems = _} => mk_xtor_co_rec_id_tac ctxt xtor_un_fold_xtor_thms xtor_un_fold_unique xtor_un_folds map_comps) |> Thm.close_derivation \<^here> |> split_conj_thm end; val co_rec_app_ss = map (fn co_rec => list_comb (co_rec, co_rec_ss)) co_recs; val co_products = @{map 2} (fn T => mk_co_product (HOLogic.id_const T)) Ts co_rec_app_ss; val co_rec_maps_rev = @{map 2} (fn Ds => mk_map_of_bnf Ds (As @ case_fp fp Ts co_rec_Xs) (As @ case_fp fp co_rec_Xs Ts)) Dss pre_bnfs; fun mk_co_app f g x = case_fp fp (f $ (g $ x)) (g $ (f $ x)); val co_rec_expand_thms = map (fn thm => thm RS case_fp fp @{thm convol_expand_snd} @{thm case_sum_expand_Inr_pointfree}) co_rec_id_thms; val xtor_co_rec_thms = let fun mk_goal co_rec s mapx xtor x absT_info_opt = let val lhs = mk_co_app co_rec xtor x; val rhs = mk_co_app s (list_comb (mapx, ids @ co_products) |> absT_info_decode thy fp absT_info_opt) x; in mk_Trueprop_eq (lhs, rhs) end; val goals = @{map 6} mk_goal co_rec_app_ss co_rec_ss co_rec_maps_rev xtors xs absT_info_opts; in map2 (fn goal => fn un_fold => Variable.add_free_names lthy goal [] |> (fn vars => Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, prems = _} => mk_xtor_co_rec_tac ctxt un_fold co_rec_defs co_rec_expand_thms)) |> Thm.close_derivation \<^here>) goals xtor_un_folds end; val co_product_fs = @{map 2} (fn T => mk_co_product (HOLogic.id_const T)) Ts fs; val co_rec_expand'_thms = map (fn thm => thm RS case_fp fp @{thm convol_expand_snd'} @{thm case_sum_expand_Inr'}) co_rec_id_thms; val xtor_co_rec_unique_thm = let fun mk_prem f s mapx xtor absT_info_opt = let val lhs = mk_co_comp f xtor; val rhs = mk_co_comp s (list_comb (mapx, ids @ co_product_fs)) |> absT_info_decode thy fp absT_info_opt; in mk_Trueprop_eq (co_swap (lhs, rhs)) end; val prems = @{map 5} mk_prem fs co_rec_ss co_rec_maps_rev xtors absT_info_opts; val concl = @{map 2} (curry HOLogic.mk_eq) fs co_rec_app_ss |> Library.foldr1 HOLogic.mk_conj |> HOLogic.mk_Trueprop; val goal = Logic.list_implies (prems, concl); val vars = Variable.add_free_names lthy goal []; fun mk_inverses NONE = [] | mk_inverses (SOME (src, dst)) = [#type_definition dst RS @{thm type_copy_Rep_o_Abs} RS rewrite_comp_comp, #type_definition src RS @{thm type_copy_Abs_o_Rep}]; val inverses = maps mk_inverses absT_info_opts; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, prems = _} => mk_xtor_co_rec_unique_tac ctxt fp co_rec_defs co_rec_expand'_thms xtor_un_fold_unique map_id0s sym_map_comp0s inverses) |> Thm.close_derivation \<^here> end; val xtor_co_rec_o_map_thms = if forall is_none absT_info_opts then mk_xtor_co_iter_o_map_thms fp true m xtor_co_rec_unique_thm (map (mk_pointfree2 lthy) xtor_maps) (map (mk_pointfree2 lthy) xtor_co_rec_thms) sym_map_comp0s map_cong0s else replicate n refl (* FIXME *); val ABphiTs = @{map 2} mk_pred2T As Bs; val XYphiTs = @{map 2} mk_pred2T Xs Ys; val ((ABphis, XYphis), names_lthy) = names_lthy |> mk_Frees "R" ABphiTs ||>> mk_Frees "S" XYphiTs; val xtor_co_rec_transfer_thms = if forall is_none absT_info_opts then let val pre_rels = @{map 2} (fn Ds => mk_rel_of_bnf Ds (As @ co_rec_Xs) (Bs @ co_rec_Ys)) Dss pre_bnfs; val rels = @{map 3} (fn T => fn T' => Thm.prop_of #> HOLogic.dest_Trueprop #> fst o dest_comb #> fst o dest_comb #> funpow n (snd o dest_comb) #> case_fp fp (fst o dest_comb #> snd o dest_comb) (snd o dest_comb) #> head_of #> force_typ names_lthy (ABphiTs ---> mk_pred2T T T')) Ts Us xtor_un_fold_transfers; fun tac {context = ctxt, prems = _} = mk_xtor_co_rec_transfer_tac ctxt fp n m co_rec_defs xtor_un_fold_transfers map_transfers xtor_rels; val mk_rel_co_product = case_fp fp mk_rel_prod mk_rel_sum; val rec_phis = map2 (fn rel => mk_rel_co_product (Term.list_comb (rel, ABphis))) rels XYphis; in mk_xtor_co_iter_transfer_thms fp pre_rels rec_phis XYphis rels ABphis co_recs (map (subst_atomic_types (ABs @ XYs)) co_recs) tac lthy end else replicate n TrueI (* FIXME *); val notes = [(case_fp fp ctor_recN dtor_corecN, xtor_co_rec_thms), (case_fp fp ctor_rec_uniqueN dtor_corec_uniqueN, split_conj_thm xtor_co_rec_unique_thm), (case_fp fp ctor_rec_o_mapN dtor_corec_o_mapN, xtor_co_rec_o_map_thms), (case_fp fp ctor_rec_transferN dtor_corec_transferN, xtor_co_rec_transfer_thms)] |> map (apsnd (map single)) |> maps (fn (thmN, thmss) => map2 (fn b => fn thms => ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])])) bs thmss); val lthy = lthy |> Config.get lthy bnf_internals ? snd o Local_Theory.notes notes; in ((co_recs, (xtor_co_rec_thms, xtor_co_rec_unique_thm, xtor_co_rec_o_map_thms, xtor_co_rec_transfer_thms)), lthy) end; fun raw_qualify extra_qualify base_b = let val qs = Binding.path_of base_b; val n = Binding.name_of base_b; in Binding.prefix_name rawN #> fold_rev (fn (s, mand) => Binding.qualify mand s) (qs @ [(n, true)]) #> extra_qualify #> Binding.concealed end; fun fixpoint_bnf force_out_of_line extra_qualify construct_fp bs resBs Ds0 Xs rhsXs comp_cache0 lthy = let val time = time lthy; val timer = time (Timer.startRealTimer ()); fun flatten_tyargs Ass = subtract (op =) Xs (filter (fn T => exists (fn Ts => member (op =) Ts T) Ass) resBs) @ Xs; val ((bnfs, (deadss, livess)), (comp_cache_unfold_set, lthy')) = apfst (apsnd split_list o split_list) (@{fold_map 2} (fn b => bnf_of_typ true Smart_Inline (raw_qualify extra_qualify b) flatten_tyargs Xs Ds0) bs rhsXs ((comp_cache0, empty_unfolds), lthy)); fun norm_qualify i = Binding.qualify true (Binding.name_of (nth bs (Int.max (0, i - 1)))) #> extra_qualify #> Binding.concealed; val Ass = map (map dest_TFree) livess; val Ds' = fold (fold Term.add_tfreesT) deadss []; val Ds = union (op =) Ds' Ds0; val missing = resBs |> fold (subtract (op =)) (Ds' :: Ass); val (dead_phantoms, live_phantoms) = List.partition (member (op =) Ds0) missing; val resBs' = resBs |> fold (subtract (op =)) [dead_phantoms, Ds]; val timer = time (timer "Construction of BNFs"); val ((kill_posss, _), (bnfs', ((comp_cache', unfold_set'), lthy''))) = normalize_bnfs norm_qualify Ass Ds (K (resBs' @ Xs)) bnfs (comp_cache_unfold_set, lthy'); val Dss = @{map 3} (fn lives => fn kill_posss => fn deads => deads @ map (nth lives) kill_posss) livess kill_posss deadss; val all_Dss = Dss |> force_out_of_line ? map (fn Ds' => union (op =) Ds' (map TFree Ds0)); fun pre_qualify b = Binding.qualify false (Binding.name_of b) #> extra_qualify #> not (Config.get lthy'' bnf_internals) ? Binding.concealed; val ((pre_bnfs, (deadss, absT_infos)), lthy''') = lthy'' |> @{fold_map 5} (fn b => seal_bnf (pre_qualify b) unfold_set' (Binding.prefix_name preN b)) bs (replicate (length rhsXs) (force_out_of_line orelse not (null live_phantoms))) Dss all_Dss bnfs' |>> split_list |>> apsnd split_list; val timer = time (timer "Normalization & sealing of BNFs"); val res = construct_fp bs resBs (map TFree dead_phantoms, deadss) pre_bnfs absT_infos lthy'''; val timer = time (timer "FP construction in total"); in (((pre_bnfs, absT_infos), comp_cache'), res) end; - -(** document antiquotations **) - -local - -fun antiquote_setup binding co = - Thy_Output.antiquotation_pretty_source_embedded binding - ((Scan.ahead (Scan.lift Parse.not_eof) >> Token.pos_of) -- - Args.type_name {proper = true, strict = true}) - (fn ctxt => fn (pos, type_name) => - let - fun err () = - error ("Bad " ^ Binding.name_of binding ^ ": " ^ quote type_name ^ Position.here pos); - in - (case Ctr_Sugar.ctr_sugar_of ctxt type_name of - NONE => err () - | SOME {kind, T = T0, ctrs = ctrs0, ...} => - let - val _ = if co = (kind = Codatatype) then () else err (); - - val T = Logic.unvarifyT_global T0; - val ctrs = map Logic.unvarify_global ctrs0; - - val pretty_typ_bracket = Syntax.pretty_typ (Config.put pretty_priority 1001 ctxt); - fun pretty_ctr ctr = - Pretty.block (Pretty.breaks (Syntax.pretty_term ctxt ctr :: - map pretty_typ_bracket (binder_types (fastype_of ctr)))); - in - Pretty.block (Pretty.keyword1 (Binding.name_of binding) :: Pretty.brk 1 :: - Syntax.pretty_typ ctxt T :: Pretty.str " =" :: Pretty.brk 1 :: - flat (separate [Pretty.brk 1, Pretty.str "| "] (map (single o pretty_ctr) ctrs))) - end) - end); - -in - -val _ = - Theory.setup - (antiquote_setup \<^binding>\datatype\ false #> - antiquote_setup \<^binding>\codatatype\ true); - end; - -end; diff --git a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML --- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML +++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML @@ -1,1190 +1,1234 @@ (* Title: HOL/Tools/Ctr_Sugar/ctr_sugar.ML Author: Jasmin Blanchette, TU Muenchen Author: Martin Desharnais, TU Muenchen Copyright 2012, 2013 Wrapping existing freely generated type's constructors. *) signature CTR_SUGAR = sig datatype ctr_sugar_kind = Datatype | Codatatype | Record | Unknown type ctr_sugar = {kind: ctr_sugar_kind, T: typ, ctrs: term list, casex: term, discs: term list, selss: term list list, exhaust: thm, nchotomy: thm, injects: thm list, distincts: thm list, case_thms: thm list, case_cong: thm, case_cong_weak: thm, case_distribs: thm list, split: thm, split_asm: thm, disc_defs: thm list, disc_thmss: thm list list, discIs: thm list, disc_eq_cases: thm list, sel_defs: thm list, sel_thmss: thm list list, distinct_discsss: thm list list list, exhaust_discs: thm list, exhaust_sels: thm list, collapses: thm list, expands: thm list, split_sels: thm list, split_sel_asms: thm list, case_eq_ifs: thm list}; val morph_ctr_sugar: morphism -> ctr_sugar -> ctr_sugar val transfer_ctr_sugar: theory -> ctr_sugar -> ctr_sugar val ctr_sugar_of: Proof.context -> string -> ctr_sugar option val ctr_sugar_of_global: theory -> string -> ctr_sugar option val ctr_sugars_of: Proof.context -> ctr_sugar list val ctr_sugars_of_global: theory -> ctr_sugar list val ctr_sugar_of_case: Proof.context -> string -> ctr_sugar option val ctr_sugar_of_case_global: theory -> string -> ctr_sugar option val ctr_sugar_interpretation: string -> (ctr_sugar -> local_theory -> local_theory) -> theory -> theory val interpret_ctr_sugar: (string -> bool) -> ctr_sugar -> local_theory -> local_theory val register_ctr_sugar_raw: ctr_sugar -> local_theory -> local_theory val register_ctr_sugar: (string -> bool) -> ctr_sugar -> local_theory -> local_theory val default_register_ctr_sugar_global: (string -> bool) -> ctr_sugar -> theory -> theory val mk_half_pairss: 'a list * 'a list -> ('a * 'a) list list val join_halves: int -> 'a list list -> 'a list list -> 'a list * 'a list list list val mk_ctr: typ list -> term -> term val mk_case: typ list -> typ -> term -> term val mk_disc_or_sel: typ list -> term -> term val name_of_ctr: term -> string val name_of_disc: term -> string val dest_ctr: Proof.context -> string -> term -> term * term list val dest_case: Proof.context -> string -> typ list -> term -> (ctr_sugar * term list * term list) option type ('c, 'a) ctr_spec = (binding * 'c) * 'a list val disc_of_ctr_spec: ('c, 'a) ctr_spec -> binding val ctr_of_ctr_spec: ('c, 'a) ctr_spec -> 'c val args_of_ctr_spec: ('c, 'a) ctr_spec -> 'a list val code_plugin: string type ctr_options = (string -> bool) * bool type ctr_options_cmd = (Proof.context -> string -> bool) * bool val fake_local_theory_for_sel_defaults: (binding * typ) list -> Proof.context -> Proof.context val free_constructors: ctr_sugar_kind -> ({prems: thm list, context: Proof.context} -> tactic) list list -> ((ctr_options * binding) * (term, binding) ctr_spec list) * term list -> local_theory -> ctr_sugar * local_theory val free_constructors_cmd: ctr_sugar_kind -> ((((Proof.context -> Plugin_Name.filter) * bool) * binding) * ((binding * string) * binding list) list) * string list -> Proof.context -> Proof.state val default_ctr_options: ctr_options val default_ctr_options_cmd: ctr_options_cmd val parse_bound_term: (binding * string) parser val parse_ctr_options: ctr_options_cmd parser val parse_ctr_spec: 'c parser -> 'a parser -> ('c, 'a) ctr_spec parser val parse_sel_default_eqs: string list parser end; structure Ctr_Sugar : CTR_SUGAR = struct open Ctr_Sugar_Util open Ctr_Sugar_Tactics open Ctr_Sugar_Code datatype ctr_sugar_kind = Datatype | Codatatype | Record | Unknown; type ctr_sugar = {kind: ctr_sugar_kind, T: typ, ctrs: term list, casex: term, discs: term list, selss: term list list, exhaust: thm, nchotomy: thm, injects: thm list, distincts: thm list, case_thms: thm list, case_cong: thm, case_cong_weak: thm, case_distribs: thm list, split: thm, split_asm: thm, disc_defs: thm list, disc_thmss: thm list list, discIs: thm list, disc_eq_cases: thm list, sel_defs: thm list, sel_thmss: thm list list, distinct_discsss: thm list list list, exhaust_discs: thm list, exhaust_sels: thm list, collapses: thm list, expands: thm list, split_sels: thm list, split_sel_asms: thm list, case_eq_ifs: thm list}; fun morph_ctr_sugar phi ({kind, T, ctrs, casex, discs, selss, exhaust, nchotomy, injects, distincts, case_thms, case_cong, case_cong_weak, case_distribs, split, split_asm, disc_defs, disc_thmss, discIs, disc_eq_cases, sel_defs, sel_thmss, distinct_discsss, exhaust_discs, exhaust_sels, collapses, expands, split_sels, split_sel_asms, case_eq_ifs} : ctr_sugar) = {kind = kind, T = Morphism.typ phi T, ctrs = map (Morphism.term phi) ctrs, casex = Morphism.term phi casex, discs = map (Morphism.term phi) discs, selss = map (map (Morphism.term phi)) selss, exhaust = Morphism.thm phi exhaust, nchotomy = Morphism.thm phi nchotomy, injects = map (Morphism.thm phi) injects, distincts = map (Morphism.thm phi) distincts, case_thms = map (Morphism.thm phi) case_thms, case_cong = Morphism.thm phi case_cong, case_cong_weak = Morphism.thm phi case_cong_weak, case_distribs = map (Morphism.thm phi) case_distribs, split = Morphism.thm phi split, split_asm = Morphism.thm phi split_asm, disc_defs = map (Morphism.thm phi) disc_defs, disc_thmss = map (map (Morphism.thm phi)) disc_thmss, discIs = map (Morphism.thm phi) discIs, disc_eq_cases = map (Morphism.thm phi) disc_eq_cases, sel_defs = map (Morphism.thm phi) sel_defs, sel_thmss = map (map (Morphism.thm phi)) sel_thmss, distinct_discsss = map (map (map (Morphism.thm phi))) distinct_discsss, exhaust_discs = map (Morphism.thm phi) exhaust_discs, exhaust_sels = map (Morphism.thm phi) exhaust_sels, collapses = map (Morphism.thm phi) collapses, expands = map (Morphism.thm phi) expands, split_sels = map (Morphism.thm phi) split_sels, split_sel_asms = map (Morphism.thm phi) split_sel_asms, case_eq_ifs = map (Morphism.thm phi) case_eq_ifs}; val transfer_ctr_sugar = morph_ctr_sugar o Morphism.transfer_morphism; structure Data = Generic_Data ( type T = ctr_sugar Symtab.table; val empty = Symtab.empty; val extend = I; fun merge data : T = Symtab.merge (K true) data; ); fun ctr_sugar_of_generic context = Option.map (transfer_ctr_sugar (Context.theory_of context)) o Symtab.lookup (Data.get context); fun ctr_sugars_of_generic context = Symtab.fold (cons o transfer_ctr_sugar (Context.theory_of context) o snd) (Data.get context) []; fun ctr_sugar_of_case_generic context s = find_first (fn {casex = Const (s', _), ...} => s' = s | _ => false) (ctr_sugars_of_generic context); val ctr_sugar_of = ctr_sugar_of_generic o Context.Proof; val ctr_sugar_of_global = ctr_sugar_of_generic o Context.Theory; val ctr_sugars_of = ctr_sugars_of_generic o Context.Proof; val ctr_sugars_of_global = ctr_sugars_of_generic o Context.Theory; val ctr_sugar_of_case = ctr_sugar_of_case_generic o Context.Proof; val ctr_sugar_of_case_global = ctr_sugar_of_case_generic o Context.Theory; structure Ctr_Sugar_Plugin = Plugin(type T = ctr_sugar); fun ctr_sugar_interpretation name f = Ctr_Sugar_Plugin.interpretation name (fn ctr_sugar => fn lthy => f (transfer_ctr_sugar (Proof_Context.theory_of lthy) ctr_sugar) lthy); val interpret_ctr_sugar = Ctr_Sugar_Plugin.data; fun register_ctr_sugar_raw (ctr_sugar as {T = Type (s, _), ...}) = Local_Theory.declaration {syntax = false, pervasive = true} (fn phi => Data.map (Symtab.update (s, morph_ctr_sugar phi ctr_sugar))); fun register_ctr_sugar plugins ctr_sugar = register_ctr_sugar_raw ctr_sugar #> interpret_ctr_sugar plugins ctr_sugar; fun default_register_ctr_sugar_global plugins (ctr_sugar as {T = Type (s, _), ...}) thy = let val tab = Data.get (Context.Theory thy) in if Symtab.defined tab s then thy else thy |> Context.theory_map (Data.put (Symtab.update_new (s, ctr_sugar) tab)) |> Named_Target.theory_map (Ctr_Sugar_Plugin.data plugins ctr_sugar) end; val is_prefix = "is_"; val un_prefix = "un_"; val not_prefix = "not_"; fun mk_unN 1 1 suf = un_prefix ^ suf | mk_unN _ l suf = un_prefix ^ suf ^ string_of_int l; val caseN = "case"; val case_congN = "case_cong"; val case_eq_ifN = "case_eq_if"; val collapseN = "collapse"; val discN = "disc"; val disc_eq_caseN = "disc_eq_case"; val discIN = "discI"; val distinctN = "distinct"; val distinct_discN = "distinct_disc"; val exhaustN = "exhaust"; val exhaust_discN = "exhaust_disc"; val expandN = "expand"; val injectN = "inject"; val nchotomyN = "nchotomy"; val selN = "sel"; val exhaust_selN = "exhaust_sel"; val splitN = "split"; val split_asmN = "split_asm"; val split_selN = "split_sel"; val split_sel_asmN = "split_sel_asm"; val splitsN = "splits"; val split_selsN = "split_sels"; val case_cong_weak_thmsN = "case_cong_weak"; val case_distribN = "case_distrib"; val cong_attrs = @{attributes [cong]}; val dest_attrs = @{attributes [dest]}; val safe_elim_attrs = @{attributes [elim!]}; val iff_attrs = @{attributes [iff]}; val inductsimp_attrs = @{attributes [induct_simp]}; val nitpicksimp_attrs = @{attributes [nitpick_simp]}; val simp_attrs = @{attributes [simp]}; fun unflat_lookup eq xs ys = map (fn xs' => permute_like_unique eq xs xs' ys); fun mk_half_pairss' _ ([], []) = [] | mk_half_pairss' indent (x :: xs, _ :: ys) = indent @ fold_rev (cons o single o pair x) ys (mk_half_pairss' ([] :: indent) (xs, ys)); fun mk_half_pairss p = mk_half_pairss' [[]] p; fun join_halves n half_xss other_half_xss = (splice (flat half_xss) (flat other_half_xss), map2 (map2 append) (Library.chop_groups n half_xss) (transpose (Library.chop_groups n other_half_xss))); fun mk_undefined T = Const (\<^const_name>\undefined\, T); fun mk_ctr Ts t = let val Type (_, Ts0) = body_type (fastype_of t) in subst_nonatomic_types (Ts0 ~~ Ts) t end; fun mk_case Ts T t = let val (Type (_, Ts0), body) = strip_type (fastype_of t) |>> List.last in subst_nonatomic_types ((body, T) :: (Ts0 ~~ Ts)) t end; fun mk_disc_or_sel Ts t = subst_nonatomic_types (snd (Term.dest_Type (domain_type (fastype_of t))) ~~ Ts) t; val name_of_ctr = name_of_const "constructor" body_type; fun name_of_disc t = (case head_of t of Abs (_, _, \<^const>\Not\ $ (t' $ Bound 0)) => Long_Name.map_base_name (prefix not_prefix) (name_of_disc t') | Abs (_, _, Const (\<^const_name>\HOL.eq\, _) $ Bound 0 $ t') => Long_Name.map_base_name (prefix is_prefix) (name_of_disc t') | Abs (_, _, \<^const>\Not\ $ (Const (\<^const_name>\HOL.eq\, _) $ Bound 0 $ t')) => Long_Name.map_base_name (prefix (not_prefix ^ is_prefix)) (name_of_disc t') | t' => name_of_const "discriminator" (perhaps (try domain_type)) t'); val base_name_of_ctr = Long_Name.base_name o name_of_ctr; fun dest_ctr ctxt s t = let val (f, args) = Term.strip_comb t in (case ctr_sugar_of ctxt s of SOME {ctrs, ...} => (case find_first (can (fo_match ctxt f)) ctrs of SOME f' => (f', args) | NONE => raise Fail "dest_ctr") | NONE => raise Fail "dest_ctr") end; fun dest_case ctxt s Ts t = (case Term.strip_comb t of (Const (c, _), args as _ :: _) => (case ctr_sugar_of ctxt s of SOME (ctr_sugar as {casex = Const (case_name, _), discs = discs0, selss = selss0, ...}) => if case_name = c then let val n = length discs0 in if n < length args then let val (branches, obj :: leftovers) = chop n args; val discs = map (mk_disc_or_sel Ts) discs0; val selss = map (map (mk_disc_or_sel Ts)) selss0; val conds = map (rapp obj) discs; val branch_argss = map (fn sels => map (rapp obj) sels @ leftovers) selss; val branches' = map2 (curry Term.betapplys) branches branch_argss; in SOME (ctr_sugar, conds, branches') end else NONE end else NONE | _ => NONE) | _ => NONE); fun const_or_free_name (Const (s, _)) = Long_Name.base_name s | const_or_free_name (Free (s, _)) = s | const_or_free_name t = raise TERM ("const_or_free_name", [t]) fun extract_sel_default ctxt t = let fun malformed () = error ("Malformed selector default value equation: " ^ Syntax.string_of_term ctxt t); val ((sel, (ctr, vars)), rhs) = fst (Term.replace_dummy_patterns (Syntax.check_term ctxt t) 0) |> HOLogic.dest_eq |>> (Term.dest_comb #>> const_or_free_name ##> (Term.strip_comb #>> (Term.dest_Const #> fst))) handle TERM _ => malformed (); in if forall (is_Free orf is_Var) vars andalso not (has_duplicates (op aconv) vars) then ((ctr, sel), fold_rev Term.lambda vars rhs) else malformed () end; (* Ideally, we would enrich the context with constants rather than free variables. *) fun fake_local_theory_for_sel_defaults sel_bTs = Proof_Context.allow_dummies #> Proof_Context.add_fixes (map (fn (b, T) => (b, SOME T, NoSyn)) sel_bTs) #> snd; type ('c, 'a) ctr_spec = (binding * 'c) * 'a list; fun disc_of_ctr_spec ((disc, _), _) = disc; fun ctr_of_ctr_spec ((_, ctr), _) = ctr; fun args_of_ctr_spec (_, args) = args; val code_plugin = Plugin_Name.declare_setup \<^binding>\code\; fun prepare_free_constructors kind prep_plugins prep_term ((((raw_plugins, discs_sels), raw_case_binding), ctr_specs), sel_default_eqs) no_defs_lthy = let val plugins = prep_plugins no_defs_lthy raw_plugins; (* TODO: sanity checks on arguments *) val raw_ctrs = map ctr_of_ctr_spec ctr_specs; val raw_disc_bindings = map disc_of_ctr_spec ctr_specs; val raw_sel_bindingss = map args_of_ctr_spec ctr_specs; val n = length raw_ctrs; val ks = 1 upto n; val _ = n > 0 orelse error "No constructors specified"; val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs; val (fcT_name, As0) = (case body_type (fastype_of (hd ctrs0)) of Type T' => T' | _ => error "Expected type constructor in body type of constructor"); val _ = forall ((fn Type (T_name, _) => T_name = fcT_name | _ => false) o body_type o fastype_of) (tl ctrs0) orelse error "Constructors not constructing same type"; val fc_b_name = Long_Name.base_name fcT_name; val fc_b = Binding.name fc_b_name; fun qualify mandatory = Binding.qualify mandatory fc_b_name; val (unsorted_As, [B, C]) = no_defs_lthy |> variant_tfrees (map (fst o dest_TFree_or_TVar) As0) ||> fst o mk_TFrees 2; val As = map2 (resort_tfree_or_tvar o snd o dest_TFree_or_TVar) As0 unsorted_As; val fcT = Type (fcT_name, As); val ctrs = map (mk_ctr As) ctrs0; val ctr_Tss = map (binder_types o fastype_of) ctrs; val ms = map length ctr_Tss; fun can_definitely_rely_on_disc k = not (Binding.is_empty (nth raw_disc_bindings (k - 1))) orelse nth ms (k - 1) = 0; fun can_rely_on_disc k = can_definitely_rely_on_disc k orelse (k = 1 andalso not (can_definitely_rely_on_disc 2)); fun should_omit_disc_binding k = n = 1 orelse (n = 2 andalso can_rely_on_disc (3 - k)); val equal_binding = \<^binding>\=\; fun is_disc_binding_valid b = not (Binding.is_empty b orelse Binding.eq_name (b, equal_binding)); val standard_disc_binding = Binding.name o prefix is_prefix o base_name_of_ctr; val disc_bindings = raw_disc_bindings |> @{map 4} (fn k => fn m => fn ctr => fn disc => qualify false (if Binding.is_empty disc then if m = 0 then equal_binding else if should_omit_disc_binding k then disc else standard_disc_binding ctr else if Binding.eq_name (disc, standard_binding) then standard_disc_binding ctr else disc)) ks ms ctrs0; fun standard_sel_binding m l = Binding.name o mk_unN m l o base_name_of_ctr; val sel_bindingss = @{map 3} (fn ctr => fn m => map2 (fn l => fn sel => qualify false (if Binding.is_empty sel orelse Binding.eq_name (sel, standard_binding) then standard_sel_binding m l ctr else sel)) (1 upto m) o pad_list Binding.empty m) ctrs0 ms raw_sel_bindingss; val add_bindings = Variable.add_fixes (distinct (op =) (filter Symbol_Pos.is_identifier (map Binding.name_of (disc_bindings @ flat sel_bindingss)))) #> snd; val case_Ts = map (fn Ts => Ts ---> B) ctr_Tss; val (((((((((u, exh_y), xss), yss), fs), gs), w), (p, p'))), _) = no_defs_lthy |> add_bindings |> yield_singleton (mk_Frees fc_b_name) fcT ||>> yield_singleton (mk_Frees "y") fcT (* for compatibility with "datatype_realizer.ML" *) ||>> mk_Freess "x" ctr_Tss ||>> mk_Freess "y" ctr_Tss ||>> mk_Frees "f" case_Ts ||>> mk_Frees "g" case_Ts ||>> yield_singleton (mk_Frees "z") B ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "P") HOLogic.boolT; val q = Free (fst p', mk_pred1T B); val xctrs = map2 (curry Term.list_comb) ctrs xss; val yctrs = map2 (curry Term.list_comb) ctrs yss; val xfs = map2 (curry Term.list_comb) fs xss; val xgs = map2 (curry Term.list_comb) gs xss; (* TODO: Eta-expension is for compatibility with the old datatype package (but it also provides nicer names). Consider removing. *) val eta_fs = map2 (fold_rev Term.lambda) xss xfs; val eta_gs = map2 (fold_rev Term.lambda) xss xgs; val case_binding = qualify false (if Binding.is_empty raw_case_binding orelse Binding.eq_name (raw_case_binding, standard_binding) then Binding.prefix_name (caseN ^ "_") fc_b else raw_case_binding); fun mk_case_disj xctr xf xs = list_exists_free xs (HOLogic.mk_conj (HOLogic.mk_eq (u, xctr), HOLogic.mk_eq (w, xf))); val case_rhs = fold_rev (fold_rev Term.lambda) [fs, [u]] (Const (\<^const_name>\The\, (B --> HOLogic.boolT) --> B) $ Term.lambda w (Library.foldr1 HOLogic.mk_disj (@{map 3} mk_case_disj xctrs xfs xss))); val ((raw_case, (_, raw_case_def)), (lthy, lthy_old)) = no_defs_lthy |> Local_Theory.open_target |> snd |> Local_Theory.define ((case_binding, NoSyn), ((Binding.concealed (Thm.def_binding case_binding), []), case_rhs)) ||> `Local_Theory.close_target; val phi = Proof_Context.export_morphism lthy_old lthy; val case_def = Morphism.thm phi raw_case_def; val case0 = Morphism.term phi raw_case; val casex = mk_case As B case0; val casexC = mk_case As C case0; val casexBool = mk_case As HOLogic.boolT case0; fun mk_uu_eq () = HOLogic.mk_eq (u, u); val exist_xs_u_eq_ctrs = map2 (fn xctr => fn xs => list_exists_free xs (HOLogic.mk_eq (u, xctr))) xctrs xss; val unique_disc_no_def = TrueI; (*arbitrary marker*) val alternate_disc_no_def = FalseE; (*arbitrary marker*) fun alternate_disc_lhs get_udisc k = HOLogic.mk_not (let val b = nth disc_bindings (k - 1) in if is_disc_binding_valid b then get_udisc b (k - 1) else nth exist_xs_u_eq_ctrs (k - 1) end); val no_discs_sels = not discs_sels andalso forall (forall Binding.is_empty) (raw_disc_bindings :: raw_sel_bindingss) andalso null sel_default_eqs; val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy) = if no_discs_sels then (true, [], [], [], [], [], lthy) else let val all_sel_bindings = flat sel_bindingss; val num_all_sel_bindings = length all_sel_bindings; val uniq_sel_bindings = distinct Binding.eq_name all_sel_bindings; val all_sels_distinct = (length uniq_sel_bindings = num_all_sel_bindings); val sel_binding_index = if all_sels_distinct then 1 upto num_all_sel_bindings else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) all_sel_bindings; val all_proto_sels = flat (@{map 3} (fn k => fn xs => map (pair k o pair xs)) ks xss xss); val sel_infos = AList.group (op =) (sel_binding_index ~~ all_proto_sels) |> sort (int_ord o apply2 fst) |> map snd |> curry (op ~~) uniq_sel_bindings; val sel_bindings = map fst sel_infos; val sel_defaults = if null sel_default_eqs then [] else let val sel_Ts = map (curry (op -->) fcT o fastype_of o snd o snd o hd o snd) sel_infos; val fake_lthy = fake_local_theory_for_sel_defaults (sel_bindings ~~ sel_Ts) no_defs_lthy; in map (extract_sel_default fake_lthy o prep_term fake_lthy) sel_default_eqs end; fun disc_free b = Free (Binding.name_of b, mk_pred1T fcT); fun disc_spec b exist_xs_u_eq_ctr = mk_Trueprop_eq (disc_free b $ u, exist_xs_u_eq_ctr); fun alternate_disc k = Term.lambda u (alternate_disc_lhs (K o rapp u o disc_free) (3 - k)); fun mk_sel_case_args b proto_sels T = @{map 3} (fn Const (c, _) => fn Ts => fn k => (case AList.lookup (op =) proto_sels k of NONE => (case filter (curry (op =) (c, Binding.name_of b) o fst) sel_defaults of [] => fold_rev (Term.lambda o curry Free Name.uu) Ts (mk_undefined T) | [(_, t)] => t | _ => error "Multiple default values for selector/constructor pair") | SOME (xs, x) => fold_rev Term.lambda xs x)) ctrs ctr_Tss ks; fun sel_spec b proto_sels = let val _ = (case duplicates (op =) (map fst proto_sels) of k :: _ => error ("Duplicate selector name " ^ quote (Binding.name_of b) ^ " for constructor " ^ quote (Syntax.string_of_term lthy (nth ctrs (k - 1)))) | [] => ()) val T = (case distinct (op =) (map (fastype_of o snd o snd) proto_sels) of [T] => T | T :: T' :: _ => error ("Inconsistent range type for selector " ^ quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ lthy T) ^ " vs. " ^ quote (Syntax.string_of_typ lthy T'))); in mk_Trueprop_eq (Free (Binding.name_of b, fcT --> T) $ u, Term.list_comb (mk_case As T case0, mk_sel_case_args b proto_sels T) $ u) end; fun unflat_selss xs = unflat_lookup Binding.eq_name sel_bindings xs sel_bindingss; val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) = lthy |> Local_Theory.open_target |> snd |> apfst split_list o @{fold_map 3} (fn k => fn exist_xs_u_eq_ctr => fn b => if Binding.is_empty b then if n = 1 then pair (Term.lambda u (mk_uu_eq ()), unique_disc_no_def) else pair (alternate_disc k, alternate_disc_no_def) else if Binding.eq_name (b, equal_binding) then pair (Term.lambda u exist_xs_u_eq_ctr, refl) else Specification.definition (SOME (b, NONE, NoSyn)) [] [] ((Thm.def_binding b, []), disc_spec b exist_xs_u_eq_ctr) #>> apsnd snd) ks exist_xs_u_eq_ctrs disc_bindings ||>> apfst split_list o fold_map (fn (b, proto_sels) => Specification.definition (SOME (b, NONE, NoSyn)) [] [] ((Thm.def_binding b, []), sel_spec b proto_sels) #>> apsnd snd) sel_infos ||> `Local_Theory.close_target; val phi = Proof_Context.export_morphism lthy lthy'; val disc_defs = map (Morphism.thm phi) raw_disc_defs; val sel_defs = map (Morphism.thm phi) raw_sel_defs; val sel_defss = unflat_selss sel_defs; val discs0 = map (Morphism.term phi) raw_discs; val selss0 = unflat_selss (map (Morphism.term phi) raw_sels); val discs = map (mk_disc_or_sel As) discs0; val selss = map (map (mk_disc_or_sel As)) selss0; in (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') end; fun mk_imp_p Qs = Logic.list_implies (Qs, HOLogic.mk_Trueprop p); val exhaust_goal = let fun mk_prem xctr xs = fold_rev Logic.all xs (mk_imp_p [mk_Trueprop_eq (exh_y, xctr)]) in fold_rev Logic.all [p, exh_y] (mk_imp_p (map2 mk_prem xctrs xss)) end; val inject_goalss = let fun mk_goal _ _ [] [] = [] | mk_goal xctr yctr xs ys = [fold_rev Logic.all (xs @ ys) (mk_Trueprop_eq (HOLogic.mk_eq (xctr, yctr), Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) xs ys)))]; in @{map 4} mk_goal xctrs yctrs xss yss end; val half_distinct_goalss = let fun mk_goal ((xs, xc), (xs', xc')) = fold_rev Logic.all (xs @ xs') (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (xc, xc')))); in map (map mk_goal) (mk_half_pairss (`I (xss ~~ xctrs))) end; val goalss = [exhaust_goal] :: inject_goalss @ half_distinct_goalss; fun after_qed ([exhaust_thm] :: thmss) lthy = let val ((((((((u, u'), (xss, xss')), fs), gs), h), v), p), _) = lthy |> add_bindings |> yield_singleton (apfst (op ~~) oo mk_Frees' fc_b_name) fcT ||>> mk_Freess' "x" ctr_Tss ||>> mk_Frees "f" case_Ts ||>> mk_Frees "g" case_Ts ||>> yield_singleton (mk_Frees "h") (B --> C) ||>> yield_singleton (mk_Frees (fc_b_name ^ "'")) fcT ||>> yield_singleton (mk_Frees "P") HOLogic.boolT; val xfs = map2 (curry Term.list_comb) fs xss; val xgs = map2 (curry Term.list_comb) gs xss; val fcase = Term.list_comb (casex, fs); val ufcase = fcase $ u; val vfcase = fcase $ v; val eta_fcase = Term.list_comb (casex, eta_fs); val eta_gcase = Term.list_comb (casex, eta_gs); val eta_ufcase = eta_fcase $ u; val eta_vgcase = eta_gcase $ v; fun mk_uu_eq () = HOLogic.mk_eq (u, u); val uv_eq = mk_Trueprop_eq (u, v); val ((inject_thms, inject_thmss), half_distinct_thmss) = chop n thmss |>> `flat; val rho_As = map (fn (T, U) => (dest_TVar T, Thm.ctyp_of lthy U)) (map Logic.varifyT_global As ~~ As); fun inst_thm t thm = Thm.instantiate' [] [SOME (Thm.cterm_of lthy t)] (Thm.instantiate (rho_As, []) (Drule.zero_var_indexes thm)); val uexhaust_thm = inst_thm u exhaust_thm; val exhaust_cases = map base_name_of_ctr ctrs; val other_half_distinct_thmss = map (map (fn thm => thm RS not_sym)) half_distinct_thmss; val (distinct_thms, (distinct_thmsss', distinct_thmsss)) = join_halves n half_distinct_thmss other_half_distinct_thmss ||> `transpose; val nchotomy_thm = let val goal = HOLogic.mk_Trueprop (HOLogic.mk_all (fst u', snd u', Library.foldr1 HOLogic.mk_disj exist_xs_u_eq_ctrs)); in Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, prems = _} => mk_nchotomy_tac ctxt n exhaust_thm) |> Thm.close_derivation \<^here> end; val case_thms = let val goals = @{map 3} (fn xctr => fn xf => fn xs => fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (fcase $ xctr, xf))) xctrs xfs xss; in @{map 4} (fn k => fn goal => fn injects => fn distinctss => Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, ...} => mk_case_tac ctxt n k case_def injects distinctss) |> Thm.close_derivation \<^here>) ks goals inject_thmss distinct_thmsss end; val (case_cong_thm, case_cong_weak_thm) = let fun mk_prem xctr xs xf xg = fold_rev Logic.all xs (Logic.mk_implies (mk_Trueprop_eq (v, xctr), mk_Trueprop_eq (xf, xg))); val goal = Logic.list_implies (uv_eq :: @{map 4} mk_prem xctrs xss xfs xgs, mk_Trueprop_eq (eta_ufcase, eta_vgcase)); val weak_goal = Logic.mk_implies (uv_eq, mk_Trueprop_eq (ufcase, vfcase)); val vars = Variable.add_free_names lthy goal []; val weak_vars = Variable.add_free_names lthy weak_goal []; in (Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, prems = _} => mk_case_cong_tac ctxt uexhaust_thm case_thms), Goal.prove_sorry lthy weak_vars [] weak_goal (fn {context = ctxt, prems = _} => etac ctxt arg_cong 1)) |> apply2 (Thm.close_derivation \<^here>) end; val split_lhs = q $ ufcase; fun mk_split_conjunct xctr xs f_xs = list_all_free xs (HOLogic.mk_imp (HOLogic.mk_eq (u, xctr), q $ f_xs)); fun mk_split_disjunct xctr xs f_xs = list_exists_free xs (HOLogic.mk_conj (HOLogic.mk_eq (u, xctr), HOLogic.mk_not (q $ f_xs))); fun mk_split_goal xctrs xss xfs = mk_Trueprop_eq (split_lhs, Library.foldr1 HOLogic.mk_conj (@{map 3} mk_split_conjunct xctrs xss xfs)); fun mk_split_asm_goal xctrs xss xfs = mk_Trueprop_eq (split_lhs, HOLogic.mk_not (Library.foldr1 HOLogic.mk_disj (@{map 3} mk_split_disjunct xctrs xss xfs))); fun prove_split selss goal = Variable.add_free_names lthy goal [] |> (fn vars => Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, prems = _} => mk_split_tac ctxt uexhaust_thm case_thms selss inject_thmss distinct_thmsss)) |> Thm.close_derivation \<^here>; fun prove_split_asm asm_goal split_thm = Variable.add_free_names lthy asm_goal [] |> (fn vars => Goal.prove_sorry lthy vars [] asm_goal (fn {context = ctxt, ...} => mk_split_asm_tac ctxt split_thm)) |> Thm.close_derivation \<^here>; val (split_thm, split_asm_thm) = let val goal = mk_split_goal xctrs xss xfs; val asm_goal = mk_split_asm_goal xctrs xss xfs; val thm = prove_split (replicate n []) goal; val asm_thm = prove_split_asm asm_goal thm; in (thm, asm_thm) end; val (sel_defs, all_sel_thms, sel_thmss, nontriv_disc_defs, disc_thmss, nontriv_disc_thmss, discI_thms, nontriv_discI_thms, distinct_disc_thms, distinct_disc_thmsss, exhaust_disc_thms, exhaust_sel_thms, all_collapse_thms, safe_collapse_thms, expand_thms, split_sel_thms, split_sel_asm_thms, case_eq_if_thms, disc_eq_case_thms) = if no_discs_sels then ([], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []) else let val udiscs = map (rapp u) discs; val uselss = map (map (rapp u)) selss; val usel_ctrs = map2 (curry Term.list_comb) ctrs uselss; val usel_fs = map2 (curry Term.list_comb) fs uselss; val vdiscs = map (rapp v) discs; val vselss = map (map (rapp v)) selss; fun make_sel_thm xs' case_thm sel_def = zero_var_indexes (Variable.gen_all lthy (Drule.rename_bvars' (map (SOME o fst) xs') (Drule.forall_intr_vars (case_thm RS (sel_def RS trans))))); val sel_thmss = @{map 3} (map oo make_sel_thm) xss' case_thms sel_defss; fun has_undefined_rhs thm = (case snd (HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.prop_of thm))) of Const (\<^const_name>\undefined\, _) => true | _ => false); val all_sel_thms = (if all_sels_distinct andalso null sel_default_eqs then flat sel_thmss else map_product (fn s => fn (xs', c) => make_sel_thm xs' c s) sel_defs (xss' ~~ case_thms)) |> filter_out has_undefined_rhs; fun mk_unique_disc_def () = let val m = the_single ms; val goal = mk_Trueprop_eq (mk_uu_eq (), the_single exist_xs_u_eq_ctrs); val vars = Variable.add_free_names lthy goal []; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, prems = _} => mk_unique_disc_def_tac ctxt m uexhaust_thm) |> Thm.close_derivation \<^here> end; fun mk_alternate_disc_def k = let val goal = mk_Trueprop_eq (alternate_disc_lhs (K (nth udiscs)) (3 - k), nth exist_xs_u_eq_ctrs (k - 1)); val vars = Variable.add_free_names lthy goal []; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, ...} => mk_alternate_disc_def_tac ctxt k (nth disc_defs (2 - k)) (nth distinct_thms (2 - k)) uexhaust_thm) |> Thm.close_derivation \<^here> end; val has_alternate_disc_def = exists (fn def => Thm.eq_thm_prop (def, alternate_disc_no_def)) disc_defs; val nontriv_disc_defs = disc_defs |> filter_out (member Thm.eq_thm_prop [unique_disc_no_def, alternate_disc_no_def, refl]); val disc_defs' = map2 (fn k => fn def => if Thm.eq_thm_prop (def, unique_disc_no_def) then mk_unique_disc_def () else if Thm.eq_thm_prop (def, alternate_disc_no_def) then mk_alternate_disc_def k else def) ks disc_defs; val discD_thms = map (fn def => def RS iffD1) disc_defs'; val discI_thms = map2 (fn m => fn def => funpow m (fn thm => exI RS thm) (def RS iffD2)) ms disc_defs'; val not_discI_thms = map2 (fn m => fn def => funpow m (fn thm => allI RS thm) (unfold_thms lthy @{thms not_ex} (def RS @{thm ssubst[of _ _ Not]}))) ms disc_defs'; val (disc_thmss', disc_thmss) = let fun mk_thm discI _ [] = refl RS discI | mk_thm _ not_discI [distinct] = distinct RS not_discI; fun mk_thms discI not_discI distinctss = map (mk_thm discI not_discI) distinctss; in @{map 3} mk_thms discI_thms not_discI_thms distinct_thmsss' |> `transpose end; val nontriv_disc_thmss = map2 (fn b => if is_disc_binding_valid b then I else K []) disc_bindings disc_thmss; fun is_discI_triv b = (n = 1 andalso Binding.is_empty b) orelse Binding.eq_name (b, equal_binding); val nontriv_discI_thms = flat (map2 (fn b => if is_discI_triv b then K [] else single) disc_bindings discI_thms); val (distinct_disc_thms, (distinct_disc_thmsss', distinct_disc_thmsss)) = let fun mk_goal [] = [] | mk_goal [((_, udisc), (_, udisc'))] = [Logic.all u (Logic.mk_implies (HOLogic.mk_Trueprop udisc, HOLogic.mk_Trueprop (HOLogic.mk_not udisc')))]; fun prove tac goal = Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, prems = _} => tac ctxt) |> Thm.close_derivation \<^here>; val half_pairss = mk_half_pairss (`I (ms ~~ discD_thms ~~ udiscs)); val half_goalss = map mk_goal half_pairss; val half_thmss = @{map 3} (fn [] => K (K []) | [goal] => fn [(((m, discD), _), _)] => fn disc_thm => [prove (fn ctxt => mk_half_distinct_disc_tac ctxt m discD disc_thm) goal]) half_goalss half_pairss (flat disc_thmss'); val other_half_goalss = map (mk_goal o map swap) half_pairss; val other_half_thmss = map2 (map2 (fn thm => prove (fn ctxt => mk_other_half_distinct_disc_tac ctxt thm))) half_thmss other_half_goalss; in join_halves n half_thmss other_half_thmss ||> `transpose |>> has_alternate_disc_def ? K [] end; val exhaust_disc_thm = let fun mk_prem udisc = mk_imp_p [HOLogic.mk_Trueprop udisc]; val goal = fold_rev Logic.all [p, u] (mk_imp_p (map mk_prem udiscs)); in Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, prems = _} => mk_exhaust_disc_tac ctxt n exhaust_thm discI_thms) |> Thm.close_derivation \<^here> end; val (safe_collapse_thms, all_collapse_thms) = let fun mk_goal m udisc usel_ctr = let val prem = HOLogic.mk_Trueprop udisc; val concl = mk_Trueprop_eq ((usel_ctr, u) |> m = 0 ? swap); in (prem aconv concl, Logic.all u (Logic.mk_implies (prem, concl))) end; val (trivs, goals) = @{map 3} mk_goal ms udiscs usel_ctrs |> split_list; val thms = @{map 5} (fn m => fn discD => fn sel_thms => fn triv => fn goal => Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, ...} => mk_collapse_tac ctxt m discD sel_thms ORELSE HEADGOAL (assume_tac ctxt)) |> Thm.close_derivation \<^here> |> not triv ? perhaps (try (fn thm => refl RS thm))) ms discD_thms sel_thmss trivs goals; in (map_filter (fn (true, _) => NONE | (false, thm) => SOME thm) (trivs ~~ thms), thms) end; val swapped_all_collapse_thms = map2 (fn m => fn thm => if m = 0 then thm else thm RS sym) ms all_collapse_thms; val exhaust_sel_thm = let fun mk_prem usel_ctr = mk_imp_p [mk_Trueprop_eq (u, usel_ctr)]; val goal = fold_rev Logic.all [p, u] (mk_imp_p (map mk_prem usel_ctrs)); in Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, prems = _} => mk_exhaust_sel_tac ctxt n exhaust_disc_thm swapped_all_collapse_thms) |> Thm.close_derivation \<^here> end; val expand_thm = let fun mk_prems k udisc usels vdisc vsels = (if k = n then [] else [mk_Trueprop_eq (udisc, vdisc)]) @ (if null usels then [] else [Logic.list_implies (if n = 1 then [] else map HOLogic.mk_Trueprop [udisc, vdisc], HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) usels vsels)))]); val goal = Library.foldr Logic.list_implies (@{map 5} mk_prems ks udiscs uselss vdiscs vselss, uv_eq); val uncollapse_thms = map2 (fn thm => fn [] => thm | _ => thm RS sym) all_collapse_thms uselss; val vars = Variable.add_free_names lthy goal []; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, prems = _} => mk_expand_tac ctxt n ms (inst_thm u exhaust_disc_thm) (inst_thm v exhaust_disc_thm) uncollapse_thms distinct_disc_thmsss distinct_disc_thmsss') |> Thm.close_derivation \<^here> end; val (split_sel_thm, split_sel_asm_thm) = let val zss = map (K []) xss; val goal = mk_split_goal usel_ctrs zss usel_fs; val asm_goal = mk_split_asm_goal usel_ctrs zss usel_fs; val thm = prove_split sel_thmss goal; val asm_thm = prove_split_asm asm_goal thm; in (thm, asm_thm) end; val case_eq_if_thm = let val goal = mk_Trueprop_eq (ufcase, mk_IfN B udiscs usel_fs); val vars = Variable.add_free_names lthy goal []; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, ...} => mk_case_eq_if_tac ctxt n uexhaust_thm case_thms disc_thmss' sel_thmss) |> Thm.close_derivation \<^here> end; val disc_eq_case_thms = let fun const_of_bool b = if b then \<^const>\True\ else \<^const>\False\; fun mk_case_args n = map_index (fn (k, argTs) => fold_rev Term.absdummy argTs (const_of_bool (n = k))) ctr_Tss; val goals = map_index (fn (n, udisc) => mk_Trueprop_eq (udisc, list_comb (casexBool, mk_case_args n) $ u)) udiscs; val goal = Logic.mk_conjunction_balanced goals; val vars = Variable.add_free_names lthy goal []; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, ...} => mk_disc_eq_case_tac ctxt (Thm.cterm_of ctxt u) exhaust_thm (flat nontriv_disc_thmss) distinct_thms case_thms) |> Thm.close_derivation \<^here> |> Conjunction.elim_balanced (length goals) end; in (sel_defs, all_sel_thms, sel_thmss, nontriv_disc_defs, disc_thmss, nontriv_disc_thmss, discI_thms, nontriv_discI_thms, distinct_disc_thms, distinct_disc_thmsss, [exhaust_disc_thm], [exhaust_sel_thm], all_collapse_thms, safe_collapse_thms, [expand_thm], [split_sel_thm], [split_sel_asm_thm], [case_eq_if_thm], disc_eq_case_thms) end; val case_distrib_thm = let val args = @{map 2} (fn f => fn argTs => let val (args, _) = mk_Frees "x" argTs lthy in fold_rev Term.lambda args (h $ list_comb (f, args)) end) fs ctr_Tss; val goal = mk_Trueprop_eq (h $ ufcase, list_comb (casexC, args) $ u); val vars = Variable.add_free_names lthy goal []; in Goal.prove_sorry lthy vars [] goal (fn {context = ctxt, ...} => mk_case_distrib_tac ctxt (Thm.cterm_of ctxt u) exhaust_thm case_thms) |> Thm.close_derivation \<^here> end; val exhaust_case_names_attr = Attrib.internal (K (Rule_Cases.case_names exhaust_cases)); val cases_type_attr = Attrib.internal (K (Induct.cases_type fcT_name)); val nontriv_disc_eq_thmss = map (map (fn th => th RS @{thm eq_False[THEN iffD2]} handle THM _ => th RS @{thm eq_True[THEN iffD2]})) nontriv_disc_thmss; val anonymous_notes = [(map (fn th => th RS notE) distinct_thms, safe_elim_attrs), (flat nontriv_disc_eq_thmss, nitpicksimp_attrs)] |> map (fn (thms, attrs) => ((Binding.empty, attrs), [(thms, [])])); val notes = [(caseN, case_thms, nitpicksimp_attrs @ simp_attrs), (case_congN, [case_cong_thm], []), (case_cong_weak_thmsN, [case_cong_weak_thm], cong_attrs), (case_distribN, [case_distrib_thm], []), (case_eq_ifN, case_eq_if_thms, []), (collapseN, safe_collapse_thms, if ms = [0] then [] else simp_attrs), (discN, flat nontriv_disc_thmss, simp_attrs), (disc_eq_caseN, disc_eq_case_thms, []), (discIN, nontriv_discI_thms, []), (distinctN, distinct_thms, simp_attrs @ inductsimp_attrs), (distinct_discN, distinct_disc_thms, dest_attrs), (exhaustN, [exhaust_thm], [exhaust_case_names_attr, cases_type_attr]), (exhaust_discN, exhaust_disc_thms, [exhaust_case_names_attr]), (exhaust_selN, exhaust_sel_thms, [exhaust_case_names_attr]), (expandN, expand_thms, []), (injectN, inject_thms, iff_attrs @ inductsimp_attrs), (nchotomyN, [nchotomy_thm], []), (selN, all_sel_thms, nitpicksimp_attrs @ simp_attrs), (splitN, [split_thm], []), (split_asmN, [split_asm_thm], []), (split_selN, split_sel_thms, []), (split_sel_asmN, split_sel_asm_thms, []), (split_selsN, split_sel_thms @ split_sel_asm_thms, []), (splitsN, [split_thm, split_asm_thm], [])] |> filter_out (null o #2) |> map (fn (thmN, thms, attrs) => ((qualify true (Binding.name thmN), attrs), [(thms, [])])); val (noted, lthy') = lthy |> Spec_Rules.add Binding.empty Spec_Rules.equational [casex] case_thms |> fold (uncurry (Spec_Rules.add Binding.empty Spec_Rules.equational)) (AList.group (eq_list (op aconv)) (map (`(single o lhs_head_of)) all_sel_thms)) |> fold (uncurry (Spec_Rules.add Binding.empty Spec_Rules.equational)) (filter_out (null o snd) (map single discs ~~ nontriv_disc_eq_thmss)) |> Local_Theory.declaration {syntax = false, pervasive = true} (fn phi => Case_Translation.register (Morphism.term phi casex) (map (Morphism.term phi) ctrs)) |> plugins code_plugin ? (Code.declare_default_eqns (map (rpair true) (flat nontriv_disc_eq_thmss @ case_thms @ all_sel_thms)) #> Local_Theory.declaration {syntax = false, pervasive = false} (fn phi => Context.mapping (add_ctr_code fcT_name (map (Morphism.typ phi) As) (map (dest_Const o Morphism.term phi) ctrs) (Morphism.fact phi inject_thms) (Morphism.fact phi distinct_thms) (Morphism.fact phi case_thms)) I)) |> Local_Theory.notes (anonymous_notes @ notes) (* for "datatype_realizer.ML": *) |>> name_noted_thms fcT_name exhaustN; val ctr_sugar = {kind = kind, T = fcT, ctrs = ctrs, casex = casex, discs = discs, selss = selss, exhaust = exhaust_thm, nchotomy = nchotomy_thm, injects = inject_thms, distincts = distinct_thms, case_thms = case_thms, case_cong = case_cong_thm, case_cong_weak = case_cong_weak_thm, case_distribs = [case_distrib_thm], split = split_thm, split_asm = split_asm_thm, disc_defs = nontriv_disc_defs, disc_thmss = disc_thmss, discIs = discI_thms, disc_eq_cases = disc_eq_case_thms, sel_defs = sel_defs, sel_thmss = sel_thmss, distinct_discsss = distinct_disc_thmsss, exhaust_discs = exhaust_disc_thms, exhaust_sels = exhaust_sel_thms, collapses = all_collapse_thms, expands = expand_thms, split_sels = split_sel_thms, split_sel_asms = split_sel_asm_thms, case_eq_ifs = case_eq_if_thms} |> morph_ctr_sugar (substitute_noted_thm noted); in (ctr_sugar, lthy' |> register_ctr_sugar plugins ctr_sugar) end; in (goalss, after_qed, lthy) end; fun free_constructors kind tacss = (fn (goalss, after_qed, lthy) => map2 (map2 (Thm.close_derivation \<^here> oo Goal.prove_sorry lthy [] [])) goalss tacss |> (fn thms => after_qed thms lthy)) oo prepare_free_constructors kind (K I) (K I); fun free_constructors_cmd kind = (fn (goalss, after_qed, lthy) => Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo prepare_free_constructors kind Plugin_Name.make_filter Syntax.read_term; val parse_bound_term = Parse.binding --| \<^keyword>\:\ -- Parse.term; type ctr_options = Plugin_Name.filter * bool; type ctr_options_cmd = (Proof.context -> Plugin_Name.filter) * bool; val default_ctr_options : ctr_options = (Plugin_Name.default_filter, false); val default_ctr_options_cmd : ctr_options_cmd = (K Plugin_Name.default_filter, false); val parse_ctr_options = Scan.optional (\<^keyword>\(\ |-- Parse.list1 (Plugin_Name.parse_filter >> (apfst o K) || Parse.reserved "discs_sels" >> (apsnd o K o K true)) --| \<^keyword>\)\ >> (fn fs => fold I fs default_ctr_options_cmd)) default_ctr_options_cmd; fun parse_ctr_spec parse_ctr parse_arg = parse_opt_binding_colon -- parse_ctr -- Scan.repeat parse_arg; val parse_ctr_specs = Parse.enum1 "|" (parse_ctr_spec Parse.term Parse.binding); val parse_sel_default_eqs = Scan.optional (\<^keyword>\where\ |-- Parse.enum1 "|" Parse.prop) []; val _ = Outer_Syntax.local_theory_to_proof \<^command_keyword>\free_constructors\ "register an existing freely generated type's constructors" (parse_ctr_options -- Parse.binding --| \<^keyword>\for\ -- parse_ctr_specs -- parse_sel_default_eqs >> free_constructors_cmd Unknown); + + +(** document antiquotations **) + +local + +fun antiquote_setup binding co = + Thy_Output.antiquotation_pretty_source_embedded binding + ((Scan.ahead (Scan.lift Parse.not_eof) >> Token.pos_of) -- + Args.type_name {proper = true, strict = true}) + (fn ctxt => fn (pos, type_name) => + let + fun err () = + error ("Bad " ^ Binding.name_of binding ^ ": " ^ quote type_name ^ Position.here pos); + in + (case ctr_sugar_of ctxt type_name of + NONE => err () + | SOME {kind, T = T0, ctrs = ctrs0, ...} => + let + val _ = if co = (kind = Codatatype) then () else err (); + + val T = Logic.unvarifyT_global T0; + val ctrs = map Logic.unvarify_global ctrs0; + + val pretty_typ_bracket = Syntax.pretty_typ (Config.put pretty_priority 1001 ctxt); + fun pretty_ctr ctr = + Pretty.block (Pretty.breaks (Syntax.pretty_term ctxt ctr :: + map pretty_typ_bracket (binder_types (fastype_of ctr)))); + in + Pretty.block (Pretty.keyword1 (Binding.name_of binding) :: Pretty.brk 1 :: + Syntax.pretty_typ ctxt T :: Pretty.str " =" :: Pretty.brk 1 :: + flat (separate [Pretty.brk 1, Pretty.str "| "] (map (single o pretty_ctr) ctrs))) + end) + end); + +in + +val _ = + Theory.setup + (antiquote_setup \<^binding>\datatype\ false #> + antiquote_setup \<^binding>\codatatype\ true); + end; + +end;