diff --git a/src/HOL/Tools/record.ML b/src/HOL/Tools/record.ML --- a/src/HOL/Tools/record.ML +++ b/src/HOL/Tools/record.ML @@ -1,2454 +1,2453 @@ (* Title: HOL/Tools/record.ML Author: Wolfgang Naraschewski, TU Muenchen Author: Markus Wenzel, TU Muenchen Author: Norbert Schirmer, TU Muenchen Author: Norbert Schirmer, Apple, 2022 Author: Thomas Sewell, NICTA Extensible records with structural subtyping. *) signature RECORD = sig val type_abbr: bool Config.T val type_as_fields: bool Config.T val timing: bool Config.T type info = {args: (string * sort) list, parent: (typ list * string) option, fields: (string * typ) list, extension: (string * typ list), ext_induct: thm, ext_inject: thm, ext_surjective: thm, ext_split: thm, ext_def: thm, select_convs: thm list, update_convs: thm list, select_defs: thm list, update_defs: thm list, fold_congs: thm list, unfold_congs: thm list, splits: thm list, defs: thm list, surjective: thm, equality: thm, induct_scheme: thm, induct: thm, cases_scheme: thm, cases: thm, simps: thm list, iffs: thm list} val get_info: theory -> string -> info option val the_info: theory -> string -> info val get_hierarchy: theory -> (string * typ list) -> (string * ((string * sort) * typ) list) list val add_record: {overloaded: bool} -> (string * sort) list * binding -> (typ list * string) option -> (binding * typ * mixfix) list -> theory -> theory val last_extT: typ -> (string * typ list) option val dest_recTs: typ -> (string * typ list) list val get_extT_fields: theory -> typ -> (string * typ) list * (string * typ) val get_recT_fields: theory -> typ -> (string * typ) list * (string * typ) val get_parent: theory -> string -> (typ list * string) option val get_extension: theory -> string -> (string * typ list) option val get_extinjects: theory -> thm list val get_simpset: theory -> simpset val simproc: simproc val eq_simproc: simproc val upd_simproc: simproc val split_simproc: (term -> int) -> simproc val ex_sel_eq_simproc: simproc val split_tac: Proof.context -> int -> tactic val split_simp_tac: Proof.context -> thm list -> (term -> int) -> int -> tactic val split_wrapper: string * (Proof.context -> wrapper) val pretty_recT: Proof.context -> typ -> Pretty.T val string_of_record: Proof.context -> string -> string val codegen: bool Config.T val sort_updates: bool Config.T val updateN: string val ext_typeN: string val extN: string end; signature ISO_TUPLE_SUPPORT = sig val add_iso_tuple_type: {overloaded: bool} -> binding * (string * sort) list -> typ * typ -> theory -> (term * term) * theory val mk_cons_tuple: term * term -> term val dest_cons_tuple: term -> term * term val iso_tuple_intros_tac: Proof.context -> int -> tactic end; structure Iso_Tuple_Support: ISO_TUPLE_SUPPORT = struct val isoN = "_Tuple_Iso"; val iso_tuple_intro = @{thm isomorphic_tuple_intro}; val iso_tuple_intros = Tactic.build_net @{thms isomorphic_tuple.intros}; val tuple_iso_tuple = (\<^const_name>\Record.tuple_iso_tuple\, @{thm tuple_iso_tuple}); structure Iso_Tuple_Thms = Theory_Data ( type T = thm Symtab.table; val empty = Symtab.make [tuple_iso_tuple]; fun merge data = Symtab.merge Thm.eq_thm_prop data; (* FIXME handle Symtab.DUP ?? *) ); fun get_typedef_info tyco vs (({rep_type, Abs_name, ...}, {Rep_inject, Abs_inverse, ... }) : Typedef.info) thy = let val exists_thm = UNIV_I |> Thm.instantiate' [SOME (Thm.global_ctyp_of thy (Logic.varifyT_global rep_type))] []; val proj_constr = Abs_inverse OF [exists_thm]; val absT = Type (tyco, map TFree vs); in thy |> pair (tyco, ((Rep_inject, proj_constr), Const (Abs_name, rep_type --> absT), absT)) end fun do_typedef overloaded raw_tyco repT raw_vs thy = let val ctxt = Proof_Context.init_global thy |> Variable.declare_typ repT; val vs = map (Proof_Context.check_tfree ctxt) raw_vs; in thy |> Named_Target.theory_map_result (apsnd o Typedef.transform_info) (Typedef.add_typedef overloaded (raw_tyco, vs, NoSyn) (HOLogic.mk_UNIV repT) NONE (fn ctxt' => resolve_tac ctxt' [UNIV_witness] 1)) |-> (fn (tyco, info) => get_typedef_info tyco vs info) end; fun mk_cons_tuple (t, u) = let val (A, B) = apply2 fastype_of (t, u) in \<^Const>\iso_tuple_cons \<^Type>\prod A B\ A B for \<^Const>\tuple_iso_tuple A B\ t u\ end; fun dest_cons_tuple \<^Const_>\iso_tuple_cons _ _ _ for \Const _\ t u\ = (t, u) | dest_cons_tuple t = raise TERM ("dest_cons_tuple", [t]); fun add_iso_tuple_type overloaded (b, alphas) (leftT, rightT) thy = let val repT = HOLogic.mk_prodT (leftT, rightT); val ((_, ((rep_inject, abs_inverse), absC, absT)), typ_thy) = thy |> do_typedef overloaded b repT alphas ||> Sign.add_path (Binding.name_of b); (*FIXME proper prefixing instead*) val typ_ctxt = Proof_Context.init_global typ_thy; (*construct a type and body for the isomorphism constant by instantiating the theorem to which the definition will be applied*) val intro_inst = rep_inject RS infer_instantiate typ_ctxt [(("abst", 0), Thm.cterm_of typ_ctxt absC)] iso_tuple_intro; val (_, body) = Logic.dest_equals (List.last (Thm.prems_of intro_inst)); val isomT = fastype_of body; val isom_binding = Binding.suffix_name isoN b; val isom_name = Sign.full_name typ_thy isom_binding; val isom = Const (isom_name, isomT); val ([isom_def], cdef_thy) = typ_thy |> Sign.declare_const_global ((isom_binding, isomT), NoSyn) |> snd |> Global_Theory.add_defs false [((Binding.concealed (Thm.def_binding isom_binding), Logic.mk_equals (isom, body)), [])]; val iso_tuple = isom_def RS (abs_inverse RS (rep_inject RS iso_tuple_intro)); val cons = \<^Const>\iso_tuple_cons absT leftT rightT\; val thm_thy = cdef_thy |> Iso_Tuple_Thms.map (Symtab.insert Thm.eq_thm_prop (isom_name, iso_tuple)) |> Sign.restore_naming thy in ((isom, cons $ isom), thm_thy) end; fun iso_tuple_intros_tac ctxt = resolve_from_net_tac ctxt iso_tuple_intros THEN' CSUBGOAL (fn (cgoal, i) => let val goal = Thm.term_of cgoal; val isthms = Iso_Tuple_Thms.get (Proof_Context.theory_of ctxt); fun err s t = raise TERM ("iso_tuple_intros_tac: " ^ s, [t]); val goal' = Envir.beta_eta_contract goal; val is = (case goal' of \<^Const_>\Trueprop for \<^Const>\isomorphic_tuple _ _ _ for \Const is\\\ => is | _ => err "unexpected goal format" goal'); val isthm = (case Symtab.lookup isthms (#1 is) of SOME isthm => isthm | NONE => err "no thm found for constant" (Const is)); in resolve_tac ctxt [isthm] i end); end; structure Record: RECORD = struct val surject_assistI = @{thm iso_tuple_surjective_proof_assistI}; val surject_assist_idE = @{thm iso_tuple_surjective_proof_assist_idE}; val updacc_accessor_eqE = @{thm update_accessor_accessor_eqE}; val updacc_updator_eqE = @{thm update_accessor_updator_eqE}; val updacc_eq_idI = @{thm iso_tuple_update_accessor_eq_assist_idI}; val updacc_eq_triv = @{thm iso_tuple_update_accessor_eq_assist_triv}; val updacc_foldE = @{thm update_accessor_congruence_foldE}; val updacc_unfoldE = @{thm update_accessor_congruence_unfoldE}; val updacc_noopE = @{thm update_accessor_noopE}; val updacc_noop_compE = @{thm update_accessor_noop_compE}; val updacc_cong_idI = @{thm update_accessor_cong_assist_idI}; val updacc_cong_triv = @{thm update_accessor_cong_assist_triv}; val updacc_cong_from_eq = @{thm iso_tuple_update_accessor_cong_from_eq}; val codegen = Attrib.setup_config_bool \<^binding>\record_codegen\ (K true); val sort_updates = Attrib.setup_config_bool \<^binding>\record_sort_updates\ (K false); (** name components **) val rN = "r"; val wN = "w"; val moreN = "more"; val schemeN = "_scheme"; val ext_typeN = "_ext"; val inner_typeN = "_inner"; val extN ="_ext"; val updateN = "_update"; val makeN = "make"; val fields_selN = "fields"; val extendN = "extend"; val truncateN = "truncate"; (*** utilities ***) fun varifyT idx = map_type_tfree (fn (a, S) => TVar ((a, idx), S)); (* timing *) val timing = Attrib.setup_config_bool \<^binding>\record_timing\ (K false); fun timeit_msg ctxt s x = if Config.get ctxt timing then (warning s; timeit x) else x (); fun timing_msg ctxt s = if Config.get ctxt timing then warning s else (); (* syntax *) val Trueprop = HOLogic.mk_Trueprop; infix 0 :== ===; infixr 0 ==>; val op :== = Misc_Legacy.mk_defpair; val op === = Trueprop o HOLogic.mk_eq; val op ==> = Logic.mk_implies; (* constructor *) fun mk_ext (name, T) ts = let val Ts = map fastype_of ts in list_comb (Const (suffix extN name, Ts ---> T), ts) end; (* selector *) fun mk_selC sT (c, T) = (c, sT --> T); fun mk_sel s (c, T) = let val sT = fastype_of s in Const (mk_selC sT (c, T)) $ s end; (* updates *) fun mk_updC sfx sT (c, T) = (suffix sfx c, (T --> T) --> sT --> sT); fun mk_upd' sfx c v sT = let val vT = domain_type (fastype_of v); in Const (mk_updC sfx sT (c, vT)) $ v end; fun mk_upd sfx c v s = mk_upd' sfx c v (fastype_of s) $ s; (* types *) fun dest_recT (typ as Type (c_ext_type, Ts as (_ :: _))) = (case try (unsuffix ext_typeN) c_ext_type of NONE => raise TYPE ("Record.dest_recT", [typ], []) | SOME c => ((c, Ts), List.last Ts)) | dest_recT typ = raise TYPE ("Record.dest_recT", [typ], []); val is_recT = can dest_recT; fun dest_recTs T = let val ((c, Ts), U) = dest_recT T in (c, Ts) :: dest_recTs U end handle TYPE _ => []; fun last_extT T = let val ((c, Ts), U) = dest_recT T in (case last_extT U of NONE => SOME (c, Ts) | SOME l => SOME l) end handle TYPE _ => NONE; fun rec_id i T = let val rTs = dest_recTs T; val rTs' = if i < 0 then rTs else take i rTs; in implode (map #1 rTs') end; (*** extend theory by record definition ***) (** record info **) (* type info and parent_info *) type info = {args: (string * sort) list, parent: (typ list * string) option, fields: (string * typ) list, extension: (string * typ list), ext_induct: thm, ext_inject: thm, ext_surjective: thm, ext_split: thm, ext_def: thm, select_convs: thm list, update_convs: thm list, select_defs: thm list, update_defs: thm list, fold_congs: thm list, (* potentially used in L4.verified *) unfold_congs: thm list, (* potentially used in L4.verified *) splits: thm list, defs: thm list, surjective: thm, equality: thm, induct_scheme: thm, induct: thm, cases_scheme: thm, cases: thm, simps: thm list, iffs: thm list}; fun make_info args parent fields extension ext_induct ext_inject ext_surjective ext_split ext_def select_convs update_convs select_defs update_defs fold_congs unfold_congs splits defs surjective equality induct_scheme induct cases_scheme cases simps iffs : info = {args = args, parent = parent, fields = fields, extension = extension, ext_induct = ext_induct, ext_inject = ext_inject, ext_surjective = ext_surjective, ext_split = ext_split, ext_def = ext_def, select_convs = select_convs, update_convs = update_convs, select_defs = select_defs, update_defs = update_defs, fold_congs = fold_congs, unfold_congs = unfold_congs, splits = splits, defs = defs, surjective = surjective, equality = equality, induct_scheme = induct_scheme, induct = induct, cases_scheme = cases_scheme, cases = cases, simps = simps, iffs = iffs}; type parent_info = {name: string, fields: (string * typ) list, extension: (string * typ list), induct_scheme: thm, ext_def: thm}; fun make_parent_info name fields extension ext_def induct_scheme : parent_info = {name = name, fields = fields, extension = extension, ext_def = ext_def, induct_scheme = induct_scheme}; (* theory data *) type data = {records: info Symtab.table, sel_upd: {selectors: (int * bool) Symtab.table, updates: string Symtab.table, simpset: simpset, defset: simpset}, equalities: thm Symtab.table, extinjects: thm list, extsplit: thm Symtab.table, (*maps extension name to split rule*) splits: (thm * thm * thm * thm) Symtab.table, (*!!, ALL, EX - split-equalities, induct rule*) extfields: (string * typ) list Symtab.table, (*maps extension to its fields*) fieldext: (string * typ list) Symtab.table}; (*maps field to its extension*) fun make_data records sel_upd equalities extinjects extsplit splits extfields fieldext = {records = records, sel_upd = sel_upd, equalities = equalities, extinjects=extinjects, extsplit = extsplit, splits = splits, extfields = extfields, fieldext = fieldext }: data; structure Data = Theory_Data ( type T = data; val empty = make_data Symtab.empty {selectors = Symtab.empty, updates = Symtab.empty, simpset = HOL_basic_ss, defset = HOL_basic_ss} Symtab.empty [] Symtab.empty Symtab.empty Symtab.empty Symtab.empty; fun merge ({records = recs1, sel_upd = {selectors = sels1, updates = upds1, simpset = ss1, defset = ds1}, equalities = equalities1, extinjects = extinjects1, extsplit = extsplit1, splits = splits1, extfields = extfields1, fieldext = fieldext1}, {records = recs2, sel_upd = {selectors = sels2, updates = upds2, simpset = ss2, defset = ds2}, equalities = equalities2, extinjects = extinjects2, extsplit = extsplit2, splits = splits2, extfields = extfields2, fieldext = fieldext2}) = make_data (Symtab.merge (K true) (recs1, recs2)) {selectors = Symtab.merge (K true) (sels1, sels2), updates = Symtab.merge (K true) (upds1, upds2), simpset = Simplifier.merge_ss (ss1, ss2), defset = Simplifier.merge_ss (ds1, ds2)} (Symtab.merge Thm.eq_thm_prop (equalities1, equalities2)) (Thm.merge_thms (extinjects1, extinjects2)) (Symtab.merge Thm.eq_thm_prop (extsplit1, extsplit2)) (Symtab.merge (fn ((a, b, c, d), (w, x, y, z)) => Thm.eq_thm (a, w) andalso Thm.eq_thm (b, x) andalso Thm.eq_thm (c, y) andalso Thm.eq_thm (d, z)) (splits1, splits2)) (Symtab.merge (K true) (extfields1, extfields2)) (Symtab.merge (K true) (fieldext1, fieldext2)); ); (* access 'records' *) val get_info = Symtab.lookup o #records o Data.get; fun the_info thy name = (case get_info thy name of SOME info => info | NONE => error ("Unknown record type " ^ quote name)); fun put_record name info = Data.map (fn {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} => make_data (Symtab.update (name, info) records) sel_upd equalities extinjects extsplit splits extfields fieldext); (* access 'sel_upd' *) val get_sel_upd = #sel_upd o Data.get; val is_selector = Symtab.defined o #selectors o get_sel_upd; val get_updates = Symtab.lookup o #updates o get_sel_upd; val get_simpset = #simpset o get_sel_upd; val get_sel_upd_defs = #defset o get_sel_upd; fun get_update_details u thy = let val sel_upd = get_sel_upd thy in (case Symtab.lookup (#updates sel_upd) u of SOME s => let val SOME (dep, ismore) = Symtab.lookup (#selectors sel_upd) s in SOME (s, dep, ismore) end | NONE => NONE) end; fun put_sel_upd names more depth simps defs thy = let val ctxt0 = Proof_Context.init_global thy; val all = names @ [more]; val sels = map (rpair (depth, false)) names @ [(more, (depth, true))]; val upds = map (suffix updateN) all ~~ all; val {records, sel_upd = {selectors, updates, simpset, defset}, equalities, extinjects, extsplit, splits, extfields, fieldext} = Data.get thy; val data = make_data records {selectors = fold Symtab.update_new sels selectors, updates = fold Symtab.update_new upds updates, simpset = simpset_map ctxt0 (fn ctxt => ctxt addsimps simps) simpset, defset = simpset_map ctxt0 (fn ctxt => ctxt addsimps defs) defset} equalities extinjects extsplit splits extfields fieldext; in Data.put data thy end; (* access 'equalities' *) fun add_equalities name thm = Data.map (fn {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} => make_data records sel_upd (Symtab.update_new (name, thm) equalities) extinjects extsplit splits extfields fieldext); val get_equalities = Symtab.lookup o #equalities o Data.get; (* access 'extinjects' *) fun add_extinjects thm = Data.map (fn {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} => make_data records sel_upd equalities (insert Thm.eq_thm_prop thm extinjects) extsplit splits extfields fieldext); val get_extinjects = rev o #extinjects o Data.get; (* access 'extsplit' *) fun add_extsplit name thm = Data.map (fn {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} => make_data records sel_upd equalities extinjects (Symtab.update_new (name, thm) extsplit) splits extfields fieldext); (* access 'splits' *) fun add_splits name thmP = Data.map (fn {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} => make_data records sel_upd equalities extinjects extsplit (Symtab.update_new (name, thmP) splits) extfields fieldext); val get_splits = Symtab.lookup o #splits o Data.get; (* parent/extension of named record *) val get_parent = (Option.join o Option.map #parent) oo (Symtab.lookup o #records o Data.get); val get_extension = Option.map #extension oo (Symtab.lookup o #records o Data.get); (* access 'extfields' *) fun add_extfields name fields = Data.map (fn {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} => make_data records sel_upd equalities extinjects extsplit splits (Symtab.update_new (name, fields) extfields) fieldext); val get_extfields = Symtab.lookup o #extfields o Data.get; fun get_extT_fields thy T = let val ((name, Ts), moreT) = dest_recT T; val recname = let val (nm :: _ :: rst) = rev (Long_Name.explode name) (* FIXME !? *) in Long_Name.implode (rev (nm :: rst)) end; val varifyT = varifyT (maxidx_of_typs (moreT :: Ts) + 1); val {records, extfields, ...} = Data.get thy; val (fields, (more, _)) = split_last (Symtab.lookup_list extfields name); val args = map varifyT (snd (#extension (the (Symtab.lookup records recname)))); val subst = fold (Sign.typ_match thy) (#1 (split_last args) ~~ #1 (split_last Ts)) Vartab.empty; val fields' = map (apsnd (Envir.norm_type subst o varifyT)) fields; in (fields', (more, moreT)) end; fun get_recT_fields thy T = let val (root_fields, (root_more, root_moreT)) = get_extT_fields thy T; val (rest_fields, rest_more) = if is_recT root_moreT then get_recT_fields thy root_moreT else ([], (root_more, root_moreT)); in (root_fields @ rest_fields, rest_more) end; (* access 'fieldext' *) fun add_fieldext extname_types fields = Data.map (fn {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} => let val fieldext' = fold (fn field => Symtab.update_new (field, extname_types)) fields fieldext; in make_data records sel_upd equalities extinjects extsplit splits extfields fieldext' end); val get_fieldext = Symtab.lookup o #fieldext o Data.get; (* parent records *) local fun add_parents _ NONE = I | add_parents thy (SOME (types, name)) = let fun err msg = error (msg ^ " parent record " ^ quote name); val {args, parent, ...} = (case get_info thy name of SOME info => info | NONE => err "Unknown"); val _ = if length types <> length args then err "Bad number of arguments for" else (); fun bad_inst ((x, S), T) = if Sign.of_sort thy (T, S) then NONE else SOME x val bads = map_filter bad_inst (args ~~ types); val _ = null bads orelse err ("Ill-sorted instantiation of " ^ commas bads ^ " in"); val inst = args ~~ types; val subst = Term.map_type_tfree (the o AList.lookup (op =) inst); val parent' = Option.map (apfst (map subst)) parent; in cons (name, inst) #> add_parents thy parent' end; in fun get_hierarchy thy (name, types) = add_parents thy (SOME (types, name)) []; fun get_parent_info thy parent = add_parents thy parent [] |> map (fn (name, inst) => let val subst = Term.map_type_tfree (the o AList.lookup (op =) inst); val {fields, extension, induct_scheme, ext_def, ...} = the_info thy name; val fields' = map (apsnd subst) fields; val extension' = apsnd (map subst) extension; in make_parent_info name fields' extension' ext_def induct_scheme end); end; (** concrete syntax for records **) (* parse translations *) local fun split_args (field :: fields) ((name, arg) :: fargs) = if can (unsuffix name) field then let val (args, rest) = split_args fields fargs in (arg :: args, rest) end else raise Fail ("expecting field " ^ quote field ^ " but got " ^ quote name) | split_args [] (fargs as (_ :: _)) = ([], fargs) | split_args (_ :: _) [] = raise Fail "expecting more fields" | split_args _ _ = ([], []); fun field_type_tr ((Const (\<^syntax_const>\_field_type\, _) $ Const (name, _) $ arg)) = (name, arg) | field_type_tr t = raise TERM ("field_type_tr", [t]); fun field_types_tr (Const (\<^syntax_const>\_field_types\, _) $ t $ u) = field_type_tr t :: field_types_tr u | field_types_tr t = [field_type_tr t]; fun record_field_types_tr more ctxt t = let val thy = Proof_Context.theory_of ctxt; fun err msg = raise TERM ("Error in record-type input: " ^ msg, [t]); fun mk_ext (fargs as (name, _) :: _) = (case get_fieldext thy (Proof_Context.intern_const ctxt name) of SOME (ext, alphas) => (case get_extfields thy ext of SOME fields => let val (fields', _) = split_last fields; val types = map snd fields'; val (args, rest) = split_args (map fst fields') fargs handle Fail msg => err msg; val argtypes = Syntax.check_typs ctxt (map Syntax_Phases.decode_typ args); val varifyT = varifyT (fold Term.maxidx_typ argtypes ~1 + 1); val vartypes = map varifyT types; val subst = Type.raw_matches (vartypes, argtypes) Vartab.empty handle Type.TYPE_MATCH => err "type is no proper record (extension)"; val alphas' = map (Syntax_Phases.term_of_typ ctxt o Envir.norm_type subst o varifyT) (#1 (split_last alphas)); val more' = mk_ext rest; in list_comb (Syntax.const (Lexicon.mark_type (suffix ext_typeN ext)), alphas' @ [more']) end | NONE => err ("no fields defined for " ^ quote ext)) | NONE => err (quote name ^ " is no proper field")) | mk_ext [] = more; in mk_ext (field_types_tr t) end; fun record_type_tr ctxt [t] = record_field_types_tr (Syntax.const \<^type_syntax>\unit\) ctxt t | record_type_tr _ ts = raise TERM ("record_type_tr", ts); fun record_type_scheme_tr ctxt [t, more] = record_field_types_tr more ctxt t | record_type_scheme_tr _ ts = raise TERM ("record_type_scheme_tr", ts); fun field_tr ((Const (\<^syntax_const>\_field\, _) $ Const (name, _) $ arg)) = (name, arg) | field_tr t = raise TERM ("field_tr", [t]); fun fields_tr (Const (\<^syntax_const>\_fields\, _) $ t $ u) = field_tr t :: fields_tr u | fields_tr t = [field_tr t]; fun record_fields_tr more ctxt t = let val thy = Proof_Context.theory_of ctxt; fun err msg = raise TERM ("Error in record input: " ^ msg, [t]); fun mk_ext (fargs as (name, _) :: _) = (case get_fieldext thy (Proof_Context.intern_const ctxt name) of SOME (ext, _) => (case get_extfields thy ext of SOME fields => let val (args, rest) = split_args (map fst (fst (split_last fields))) fargs handle Fail msg => err msg; val more' = mk_ext rest; in list_comb (Syntax.const (Lexicon.mark_const (ext ^ extN)), args @ [more']) end | NONE => err ("no fields defined for " ^ quote ext)) | NONE => err (quote name ^ " is no proper field")) | mk_ext [] = more; in mk_ext (fields_tr t) end; fun record_tr ctxt [t] = record_fields_tr (Syntax.const \<^const_syntax>\Unity\) ctxt t | record_tr _ ts = raise TERM ("record_tr", ts); fun record_scheme_tr ctxt [t, more] = record_fields_tr more ctxt t | record_scheme_tr _ ts = raise TERM ("record_scheme_tr", ts); fun field_update_tr (Const (\<^syntax_const>\_field_update\, _) $ Const (name, _) $ arg) = Syntax.const (suffix updateN name) $ Abs (Name.uu_, dummyT, arg) | field_update_tr t = raise TERM ("field_update_tr", [t]); fun field_updates_tr (Const (\<^syntax_const>\_field_updates\, _) $ t $ u) = field_update_tr t :: field_updates_tr u | field_updates_tr t = [field_update_tr t]; fun record_update_tr [t, u] = fold (curry op $) (field_updates_tr u) t | record_update_tr ts = raise TERM ("record_update_tr", ts); in val _ = Theory.setup (Sign.parse_translation [(\<^syntax_const>\_record_update\, K record_update_tr), (\<^syntax_const>\_record\, record_tr), (\<^syntax_const>\_record_scheme\, record_scheme_tr), (\<^syntax_const>\_record_type\, record_type_tr), (\<^syntax_const>\_record_type_scheme\, record_type_scheme_tr)]); end; (* print translations *) val type_abbr = Attrib.setup_config_bool \<^binding>\record_type_abbr\ (K true); val type_as_fields = Attrib.setup_config_bool \<^binding>\record_type_as_fields\ (K true); local (* FIXME early extern (!??) *) (* FIXME Syntax.free (??) *) fun field_type_tr' (c, t) = Syntax.const \<^syntax_const>\_field_type\ $ Syntax.const c $ t; fun field_types_tr' (t, u) = Syntax.const \<^syntax_const>\_field_types\ $ t $ u; fun record_type_tr' ctxt t = let val thy = Proof_Context.theory_of ctxt; val T = Syntax_Phases.decode_typ t; val varifyT = varifyT (Term.maxidx_of_typ T + 1); fun strip_fields T = (case T of Type (ext, args as _ :: _) => (case try (unsuffix ext_typeN) ext of SOME ext' => (case get_extfields thy ext' of SOME (fields as (x, _) :: _) => (case get_fieldext thy x of SOME (_, alphas) => (let val (f :: fs, _) = split_last fields; val fields' = apfst (Proof_Context.extern_const ctxt) f :: map (apfst Long_Name.base_name) fs; val (args', more) = split_last args; val alphavars = map varifyT (#1 (split_last alphas)); val subst = Type.raw_matches (alphavars, args') Vartab.empty; val fields'' = (map o apsnd) (Envir.norm_type subst o varifyT) fields'; in fields'' @ strip_fields more end handle Type.TYPE_MATCH => [("", T)]) | _ => [("", T)]) | _ => [("", T)]) | _ => [("", T)]) | _ => [("", T)]); val (fields, (_, moreT)) = split_last (strip_fields T); val _ = null fields andalso raise Match; val u = foldr1 field_types_tr' (map (field_type_tr' o apsnd (Syntax_Phases.term_of_typ ctxt)) fields); in if not (Config.get ctxt type_as_fields) orelse null fields then raise Match else if moreT = HOLogic.unitT then Syntax.const \<^syntax_const>\_record_type\ $ u else Syntax.const \<^syntax_const>\_record_type_scheme\ $ u $ Syntax_Phases.term_of_typ ctxt moreT end; (*try to reconstruct the record name type abbreviation from the (nested) extension types*) fun record_type_abbr_tr' abbr alphas zeta last_ext schemeT ctxt tm = let val T = Syntax_Phases.decode_typ tm; val varifyT = varifyT (maxidx_of_typ T + 1); fun mk_type_abbr subst name args = let val abbrT = Type (name, map (varifyT o TFree) args) in Syntax_Phases.term_of_typ ctxt (Envir.norm_type subst abbrT) end; fun match rT T = Type.raw_match (varifyT rT, T) Vartab.empty; in if Config.get ctxt type_abbr then (case last_extT T of SOME (name, _) => if name = last_ext then let val subst = match schemeT T in if HOLogic.is_unitT (Envir.norm_type subst (varifyT (TFree zeta))) then mk_type_abbr subst abbr alphas else mk_type_abbr subst (suffix schemeN abbr) (alphas @ [zeta]) end handle Type.TYPE_MATCH => record_type_tr' ctxt tm else raise Match (*give print translation of specialised record a chance*) | _ => raise Match) else record_type_tr' ctxt tm end; in fun record_ext_type_tr' name = let val ext_type_name = Lexicon.mark_type (suffix ext_typeN name); fun tr' ctxt ts = record_type_tr' ctxt (list_comb (Syntax.const ext_type_name, ts)); in (ext_type_name, tr') end; fun record_ext_type_abbr_tr' abbr alphas zeta last_ext schemeT name = let val ext_type_name = Lexicon.mark_type (suffix ext_typeN name); fun tr' ctxt ts = record_type_abbr_tr' abbr alphas zeta last_ext schemeT ctxt (list_comb (Syntax.const ext_type_name, ts)); in (ext_type_name, tr') end; end; local (* FIXME Syntax.free (??) *) fun field_tr' (c, t) = Syntax.const \<^syntax_const>\_field\ $ Syntax.const c $ t; fun fields_tr' (t, u) = Syntax.const \<^syntax_const>\_fields\ $ t $ u; fun record_tr' ctxt t = let val thy = Proof_Context.theory_of ctxt; fun strip_fields t = (case strip_comb t of (Const (ext, _), args as (_ :: _)) => (case try (Lexicon.unmark_const o unsuffix extN) ext of SOME ext' => (case get_extfields thy ext' of SOME fields => (let val (f :: fs, _) = split_last (map fst fields); val fields' = Proof_Context.extern_const ctxt f :: map Long_Name.base_name fs; val (args', more) = split_last args; in (fields' ~~ args') @ strip_fields more end handle ListPair.UnequalLengths => [("", t)]) | NONE => [("", t)]) | NONE => [("", t)]) | _ => [("", t)]); val (fields, (_, more)) = split_last (strip_fields t); val _ = null fields andalso raise Match; val u = foldr1 fields_tr' (map field_tr' fields); in (case more of Const (\<^const_syntax>\Unity\, _) => Syntax.const \<^syntax_const>\_record\ $ u | _ => Syntax.const \<^syntax_const>\_record_scheme\ $ u $ more) end; in fun record_ext_tr' name = let val ext_name = Lexicon.mark_const (name ^ extN); fun tr' ctxt ts = record_tr' ctxt (list_comb (Syntax.const ext_name, ts)); in (ext_name, tr') end; end; local fun dest_update ctxt c = (case try Lexicon.unmark_const c of SOME d => try (unsuffix updateN) (Proof_Context.extern_const ctxt d) | NONE => NONE); fun field_updates_tr' ctxt (tm as Const (c, _) $ k $ u) = (case dest_update ctxt c of SOME name => (case try Syntax_Trans.const_abs_tr' k of SOME t => apfst (cons (Syntax.const \<^syntax_const>\_field_update\ $ Syntax.free name $ t)) (field_updates_tr' ctxt u) | NONE => ([], tm)) | NONE => ([], tm)) | field_updates_tr' _ tm = ([], tm); fun record_update_tr' ctxt tm = (case field_updates_tr' ctxt tm of ([], _) => raise Match | (ts, u) => Syntax.const \<^syntax_const>\_record_update\ $ u $ foldr1 (fn (v, w) => Syntax.const \<^syntax_const>\_field_updates\ $ v $ w) (rev ts)); in fun field_update_tr' name = let val update_name = Lexicon.mark_const (name ^ updateN); fun tr' ctxt [t, u] = record_update_tr' ctxt (Syntax.const update_name $ t $ u) | tr' _ _ = raise Match; in (update_name, tr') end; end; (** record simprocs **) fun is_sel_upd_pair thy (Const (s, _)) (Const (u, t')) = (case get_updates thy u of SOME u_name => u_name = s | NONE => raise TERM ("is_sel_upd_pair: not update", [Const (u, t')])); fun mk_comp_id f = let val T = range_type (fastype_of f) in HOLogic.mk_comp (\<^Const>\id T\, f) end; fun get_upd_funs (upd $ _ $ t) = upd :: get_upd_funs t | get_upd_funs _ = []; fun get_accupd_simps ctxt term defset = let val thy = Proof_Context.theory_of ctxt; val (acc, [body]) = strip_comb term; val upd_funs = sort_distinct Term_Ord.fast_term_ord (get_upd_funs body); fun get_simp upd = let (* FIXME fresh "f" (!?) *) val T = domain_type (fastype_of upd); val lhs = HOLogic.mk_comp (acc, upd $ Free ("f", T)); val rhs = if is_sel_upd_pair thy acc upd then HOLogic.mk_comp (Free ("f", T), acc) else mk_comp_id acc; val prop = lhs === rhs; val othm = Goal.prove ctxt [] [] prop (fn {context = ctxt', ...} => simp_tac (put_simpset defset ctxt') 1 THEN REPEAT_DETERM (Iso_Tuple_Support.iso_tuple_intros_tac ctxt' 1) THEN TRY (simp_tac (put_simpset HOL_ss ctxt' addsimps @{thms id_apply id_o o_id}) 1)); val dest = if is_sel_upd_pair thy acc upd then @{thm o_eq_dest} else @{thm o_eq_id_dest}; in Drule.export_without_context (othm RS dest) end; in map get_simp upd_funs end; fun get_updupd_simp ctxt defset u u' comp = let (* FIXME fresh "f" (!?) *) val f = Free ("f", domain_type (fastype_of u)); val f' = Free ("f'", domain_type (fastype_of u')); val lhs = HOLogic.mk_comp (u $ f, u' $ f'); val rhs = if comp then u $ HOLogic.mk_comp (f, f') else HOLogic.mk_comp (u' $ f', u $ f); val prop = lhs === rhs; val othm = Goal.prove ctxt [] [] prop (fn {context = ctxt', ...} => simp_tac (put_simpset defset ctxt') 1 THEN REPEAT_DETERM (Iso_Tuple_Support.iso_tuple_intros_tac ctxt' 1) THEN TRY (simp_tac (put_simpset HOL_ss ctxt' addsimps @{thms id_apply}) 1)); val dest = if comp then @{thm o_eq_dest_lhs} else @{thm o_eq_dest}; in Drule.export_without_context (othm RS dest) end; fun gen_get_updupd_simps ctxt upd_funs defset = let val cname = fst o dest_Const; fun getswap u u' = get_updupd_simp ctxt defset u u' (cname u = cname u'); fun build_swaps_to_eq _ [] swaps = swaps | build_swaps_to_eq upd (u :: us) swaps = let val key = (cname u, cname upd); val newswaps = if Symreltab.defined swaps key then swaps else Symreltab.insert (K true) (key, getswap u upd) swaps; in if cname u = cname upd then newswaps else build_swaps_to_eq upd us newswaps end; fun swaps_needed [] _ _ swaps = map snd (Symreltab.dest swaps) | swaps_needed (u :: us) prev seen swaps = if Symtab.defined seen (cname u) then swaps_needed us prev seen (build_swaps_to_eq u prev swaps) else swaps_needed us (u :: prev) (Symtab.insert (K true) (cname u, ()) seen) swaps; in swaps_needed upd_funs [] Symtab.empty Symreltab.empty end; fun get_updupd_simps ctxt term defset = gen_get_updupd_simps ctxt (get_upd_funs term) defset; fun prove_unfold_defs thy upd_funs ex_simps ex_simprs prop = let val ctxt = Proof_Context.init_global thy; val defset = get_sel_upd_defs thy; val prop' = Envir.beta_eta_contract prop; val (lhs, _) = Logic.dest_equals (Logic.strip_assums_concl prop'); val (_, args) = strip_comb lhs; val simps = if null upd_funs then (if length args = 1 then get_accupd_simps else get_updupd_simps) ctxt lhs defset else gen_get_updupd_simps ctxt upd_funs defset in Goal.prove ctxt [] [] prop' (fn {context = ctxt', ...} => simp_tac (put_simpset HOL_basic_ss ctxt' addsimps (simps @ @{thms K_record_comp})) 1 THEN TRY (simp_tac (put_simpset HOL_basic_ss ctxt' addsimps ex_simps addsimprocs ex_simprs) 1)) end; local fun eq (s1: string) (s2: string) = (s1 = s2); fun has_field extfields f T = exists (fn (eN, _) => exists (eq f o fst) (Symtab.lookup_list extfields eN)) (dest_recTs T); fun K_skeleton n (T as Type (_, [_, kT])) (b as Bound i) (Abs (x, xT, t)) = if null (loose_bnos t) then ((n, kT), (Abs (x, xT, Bound (i + 1)))) else ((n, T), b) | K_skeleton n T b _ = ((n, T), b); in (* simproc *) (* Simplify selections of an record update: (1) S (S_update k r) = k (S r) (2) S (X_update k r) = S r The simproc skips multiple updates at once, eg: S (X_update x (Y_update y (S_update k r))) = k (S r) But be careful in (2) because of the extensibility of records. - If S is a more-selector we have to make sure that the update on component X does not affect the selected subrecord. - If X is a more-selector we have to make sure that S is not in the updated subrecord. *) val _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\select_update\ {lhss = [\<^term>\x::'a::{}\], proc = fn _ => fn ctxt => fn ct => let val thy = Proof_Context.theory_of ctxt; val t = Thm.term_of ct; in (case t of (sel as Const (s, Type (_, [_, rangeS]))) $ ((upd as Const (u, Type (_, [_, Type (_, [rT, _])]))) $ k $ r) => if is_selector thy s andalso is_some (get_updates thy u) then let val {sel_upd = {updates, ...}, extfields, ...} = Data.get thy; fun mk_eq_terms ((upd as Const (u, Type(_, [kT, _]))) $ k $ r) = (case Symtab.lookup updates u of NONE => NONE | SOME u_name => if u_name = s then (case mk_eq_terms r of NONE => let val rv = ("r", rT); val rb = Bound 0; val (kv, kb) = K_skeleton "k" kT (Bound 1) k; in SOME (upd $ kb $ rb, kb $ (sel $ rb), [kv, rv]) end | SOME (trm, trm', vars) => let val (kv, kb) = K_skeleton "k" kT (Bound (length vars)) k; in SOME (upd $ kb $ trm, kb $ trm', kv :: vars) end) else if has_field extfields u_name rangeS orelse has_field extfields s (domain_type kT) then NONE else (case mk_eq_terms r of SOME (trm, trm', vars) => let val (kv, kb) = K_skeleton "k" kT (Bound (length vars)) k in SOME (upd $ kb $ trm, trm', kv :: vars) end | NONE => let val rv = ("r", rT); val rb = Bound 0; val (kv, kb) = K_skeleton "k" kT (Bound 1) k; in SOME (upd $ kb $ rb, sel $ rb, [kv, rv]) end)) | mk_eq_terms _ = NONE; in (case mk_eq_terms (upd $ k $ r) of SOME (trm, trm', vars) => SOME (prove_unfold_defs thy [] [] [] (Logic.list_all (vars, Logic.mk_equals (sel $ trm, trm')))) | NONE => NONE) end else NONE | _ => NONE) end})); -val simproc_name = - Simplifier.check_simproc (Context.the_local_context ()) ("select_update", Position.none); -val simproc = Simplifier.the_simproc (Context.the_local_context ()) simproc_name; +val simproc = + #2 (Simplifier.check_simproc (Context.the_local_context ()) ("select_update", Position.none)); fun get_upd_acc_cong_thm upd acc thy ss = let val ctxt = Proof_Context.init_global thy; val prop = infer_instantiate ctxt [(("upd", 0), Thm.cterm_of ctxt upd), (("ac", 0), Thm.cterm_of ctxt acc)] updacc_cong_triv |> Thm.concl_of; in Goal.prove ctxt [] [] prop (fn {context = ctxt', ...} => simp_tac (put_simpset ss ctxt') 1 THEN REPEAT_DETERM (Iso_Tuple_Support.iso_tuple_intros_tac ctxt' 1) THEN TRY (resolve_tac ctxt' [updacc_cong_idI] 1)) end; fun sorted ord [] = true | sorted ord [x] = true | sorted ord (x::y::xs) = (case ord (x, y) of LESS => sorted ord (y::xs) | EQUAL => sorted ord (y::xs) | GREATER => false) fun insert_unique ord x [] = [x] | insert_unique ord x (y::ys) = (case ord (x, y) of LESS => (x::y::ys) | EQUAL => (x::ys) | GREATER => y :: insert_unique ord x ys) fun insert_unique_hd ord (x::xs) = x :: insert_unique ord x xs | insert_unique_hd ord xs = xs (* upd_simproc *) (*Simplify multiple updates: (1) "N_update y (M_update g (N_update x (M_update f r))) = (N_update (y o x) (M_update (g o f) r))" (2) "r(|M:= M r|) = r" In both cases "more" updates complicate matters: for this reason we omit considering further updates if doing so would introduce both a more update and an update to a field within it.*) val _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\update\ {lhss = [\<^term>\x::'a::{}\], proc = fn _ => fn ctxt => fn ct => let val thy = Proof_Context.theory_of ctxt; val t = Thm.term_of ct; (*We can use more-updators with other updators as long as none of the other updators go deeper than any more updator. min here is the depth of the deepest other updator, max the depth of the shallowest more updator.*) fun include_depth (dep, true) (min, max) = if min <= dep then SOME (min, if dep <= max orelse max = ~1 then dep else max) else NONE | include_depth (dep, false) (min, max) = if dep <= max orelse max = ~1 then SOME (if min <= dep then dep else min, max) else NONE; fun getupdseq (term as (upd as Const (u, _)) $ f $ tm) min max = (case get_update_details u thy of SOME (s, dep, ismore) => (case include_depth (dep, ismore) (min, max) of SOME (min', max') => let val (us, bs, _) = getupdseq tm min' max' in ((upd, s, f) :: us, bs, fastype_of term) end | NONE => ([], term, HOLogic.unitT)) | NONE => ([], term, HOLogic.unitT)) | getupdseq term _ _ = ([], term, HOLogic.unitT); val (upds, base, baseT) = getupdseq t 0 ~1; val orig_upds = map_index (fn (i, (x, y, z)) => (x, y, z, i)) upds val upd_ord = rev_order o fast_string_ord o apply2 #2 val (upds, commuted) = if not (null orig_upds) andalso Config.get ctxt sort_updates andalso not (sorted upd_ord orig_upds) then (sort upd_ord orig_upds, true) else (orig_upds, false) fun is_upd_noop s (Abs (n, T, Const (s', T') $ tm')) tm = if s = s' andalso null (loose_bnos tm') andalso subst_bound (HOLogic.unit, tm') = tm then (true, Abs (n, T, Const (s', T') $ Bound 1)) else (false, HOLogic.unit) | is_upd_noop _ _ _ = (false, HOLogic.unit); fun get_noop_simps (upd as Const _) (Abs (_, _, (acc as Const _) $ _)) = let val ss = get_sel_upd_defs thy; val uathm = get_upd_acc_cong_thm upd acc thy ss; in [Drule.export_without_context (uathm RS updacc_noopE), Drule.export_without_context (uathm RS updacc_noop_compE)] end; (*If f is constant then (f o g) = f. We know that K_skeleton only returns constant abstractions thus when we see an abstraction we can discard inner updates.*) fun add_upd (f as Abs _) _ = [f] | add_upd f fs = (f :: fs); (*mk_updterm returns (orig-term-skeleton-update list , simplified-skeleton, variables, duplicate-updates, simp-flag, noop-simps) where duplicate-updates is a table used to pass upward the list of update functions which can be composed into an update above them, simp-flag indicates whether any simplification was achieved, and noop-simps are used for eliminating case (2) defined above*) fun mk_updterm ((upd as Const (u, T), s, f, i) :: upds) above term = let val (lhs_upds, rhs, vars, dups, simp, noops) = mk_updterm upds (Symtab.update (u, ()) above) term; val (fvar, skelf) = K_skeleton (Long_Name.base_name s) (domain_type T) (Bound (length vars)) f; val (isnoop, skelf') = is_upd_noop s f term; val funT = domain_type T; fun mk_comp_local (f, f') = Const (\<^const_name>\Fun.comp\, funT --> funT --> funT) $ f $ f'; in if isnoop then ((upd $ skelf', i)::lhs_upds, rhs, vars, Symtab.update (u, []) dups, true, if Symtab.defined noops u then noops else Symtab.update (u, get_noop_simps upd skelf') noops) else if Symtab.defined above u then ((upd $ skelf, i)::lhs_upds, rhs, fvar :: vars, Symtab.map_default (u, []) (add_upd skelf) dups, true, noops) else (case Symtab.lookup dups u of SOME fs => ((upd $ skelf, i)::lhs_upds, upd $ foldr1 mk_comp_local (add_upd skelf fs) $ rhs, fvar :: vars, dups, true, noops) | NONE => ((upd $ skelf, i)::lhs_upds, upd $ skelf $ rhs, fvar :: vars, dups, simp, noops)) end | mk_updterm [] _ _ = ([], Bound 0, [("r", baseT)], Symtab.empty, false, Symtab.empty) | mk_updterm us _ _ = raise TERM ("mk_updterm match", map (fn (x, _, _, _) => x) us); val (lhs_upds, rhs, vars, _, simp, noops) = mk_updterm upds Symtab.empty base; val orig_order_lhs_upds = lhs_upds |> sort (rev_order o int_ord o apply2 snd) val lhs = Bound 0 |> fold (fn (upd, _) => fn s => upd $ s) orig_order_lhs_upds (* Note that the simplifier works bottom up. So all nested updates are already normalised, e.g. sorted. 'commuted' thus means that the outermost update has to be inserted at its place inside the sorted nested updates. The necessary swaps can be expressed via 'upd_funs' by replicating the outer update at the designated position: *) val upd_funs = (if commuted then insert_unique_hd upd_ord orig_upds else orig_upds) |> map #1 val noops' = maps snd (Symtab.dest noops); in if simp orelse commuted then SOME (prove_unfold_defs thy upd_funs noops' [simproc] (Logic.list_all (vars, Logic.mk_equals (lhs, rhs)))) else NONE end})); -val upd_simproc_name = - Simplifier.check_simproc (Context.the_local_context ()) ("update", Position.none); -val upd_simproc = Simplifier.the_simproc (Context.the_local_context ()) upd_simproc_name; +val upd_simproc = + #2 (Simplifier.check_simproc (Context.the_local_context ()) ("update", Position.none)); end; (* eq_simproc *) (*Look up the most specific record-equality. Note on efficiency: Testing equality of records boils down to the test of equality of all components. Therefore the complexity is: #components * complexity for single component. Especially if a record has a lot of components it may be better to split up the record first and do simplification on that (split_simp_tac). e.g. r(|lots of updates|) = x eq_simproc split_simp_tac Complexity: #components * #updates #updates *) val _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\eq\ {lhss = [\<^term>\r = s\], proc = fn _ => fn ctxt => fn ct => (case Thm.term_of ct of \<^Const_>\HOL.eq T for _ _\ => (case rec_id ~1 T of "" => NONE | name => (case get_equalities (Proof_Context.theory_of ctxt) name of NONE => NONE | SOME thm => SOME (thm RS @{thm Eq_TrueI}))) | _ => NONE)})); -val eq_simproc_name = Simplifier.check_simproc (Context.the_local_context ()) ("eq", Position.none); -val eq_simproc = Simplifier.the_simproc (Context.the_local_context ()) eq_simproc_name; +val eq_simproc = + #2 (Simplifier.check_simproc (Context.the_local_context ()) ("eq", Position.none)); + (* split_simproc *) (*Split quantified occurrences of records, for which P holds. P can peek on the subterm starting at the quantified occurrence of the record (including the quantifier): P t = 0: do not split P t = ~1: completely split P t > 0: split up to given bound of record extensions.*) fun split_simproc P = Simplifier.make_simproc \<^context> "record_split" {lhss = [\<^term>\x::'a::{}\], proc = fn _ => fn ctxt => fn ct => (case Thm.term_of ct of Const (quantifier, Type (_, [Type (_, [T, _]), _])) $ _ => if quantifier = \<^const_name>\Pure.all\ orelse quantifier = \<^const_name>\All\ orelse quantifier = \<^const_name>\Ex\ then (case rec_id ~1 T of "" => NONE | _ => let val split = P (Thm.term_of ct) in if split <> 0 then (case get_splits (Proof_Context.theory_of ctxt) (rec_id split T) of NONE => NONE | SOME (all_thm, All_thm, Ex_thm, _) => SOME (case quantifier of \<^const_name>\Pure.all\ => all_thm | \<^const_name>\All\ => All_thm RS @{thm eq_reflection} | \<^const_name>\Ex\ => Ex_thm RS @{thm eq_reflection} | _ => raise Fail "split_simproc")) else NONE end) else NONE | _ => NONE)}; val _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\ex_sel_eq\ {lhss = [\<^term>\Ex t\], proc = fn _ => fn ctxt => fn ct => let val thy = Proof_Context.theory_of ctxt; val t = Thm.term_of ct; fun mkeq (lr, T, (sel, Tsel), x) i = if is_selector thy sel then let val x' = if not (Term.is_dependent x) then Free ("x" ^ string_of_int i, range_type Tsel) else raise TERM ("", [x]); val sel' = Const (sel, Tsel) $ Bound 0; val (l, r) = if lr then (sel', x') else (x', sel'); in \<^Const>\HOL.eq T for l r\ end else raise TERM ("", [Const (sel, Tsel)]); fun dest_sel_eq (\<^Const_>\HOL.eq T\ $ (Const (sel, Tsel) $ Bound 0) $ X) = (true, T, (sel, Tsel), X) | dest_sel_eq (\<^Const_>\HOL.eq T\ $ X $ (Const (sel, Tsel) $ Bound 0)) = (false, T, (sel, Tsel), X) | dest_sel_eq _ = raise TERM ("", []); in (case t of \<^Const_>\Ex T for \Abs (s, _, t)\\ => (let val eq = mkeq (dest_sel_eq t) 0; val prop = Logic.list_all ([("r", T)], Logic.mk_equals (\<^Const>\Ex T for \Abs (s, T, eq)\\, \<^Const>\True\)); in SOME (Goal.prove_sorry_global thy [] [] prop (fn {context = ctxt', ...} => simp_tac (put_simpset (get_simpset thy) ctxt' addsimps @{thms simp_thms} addsimprocs [split_simproc (K ~1)]) 1)) end handle TERM _ => NONE) | _ => NONE) end})); -val ex_sel_eq_simproc_name = - Simplifier.check_simproc (Context.the_local_context ()) ("ex_sel_eq", Position.none); -val ex_sel_eq_simproc = Simplifier.the_simproc (Context.the_local_context ()) ex_sel_eq_simproc_name; +val ex_sel_eq_simproc = + #2 (Simplifier.check_simproc (Context.the_local_context ()) ("ex_sel_eq", Position.none)); + val _ = Theory.setup (map_theory_simpset (fn ctxt => ctxt delsimprocs [ex_sel_eq_simproc])); (* split_simp_tac *) (*Split (and simplify) all records in the goal for which P holds. For quantified occurrences of a record P can peek on the whole subterm (including the quantifier); for free variables P can only peek on the variable itself. P t = 0: do not split P t = ~1: completely split P t > 0: split up to given bound of record extensions.*) fun split_simp_tac ctxt thms P = CSUBGOAL (fn (cgoal, i) => let val thy = Proof_Context.theory_of ctxt; val goal = Thm.term_of cgoal; val frees = filter (is_recT o #2) (Term.add_frees goal []); val has_rec = exists_Const (fn (s, Type (_, [Type (_, [T, _]), _])) => (s = \<^const_name>\Pure.all\ orelse s = \<^const_name>\All\ orelse s = \<^const_name>\Ex\) andalso is_recT T | _ => false); fun mk_split_free_tac free induct_thm i = let val _ $ (_ $ Var (r, _)) = Thm.concl_of induct_thm; val thm = infer_instantiate ctxt [(r, Thm.cterm_of ctxt free)] induct_thm; in simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms induct_atomize}) i THEN resolve_tac ctxt [thm] i THEN simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms induct_rulify}) i end; val split_frees_tacs = frees |> map_filter (fn (x, T) => (case rec_id ~1 T of "" => NONE | _ => let val free = Free (x, T); val split = P free; in if split <> 0 then (case get_splits thy (rec_id split T) of NONE => NONE | SOME (_, _, _, induct_thm) => SOME (mk_split_free_tac free induct_thm i)) else NONE end)); val simprocs = if has_rec goal then [split_simproc P] else []; val thms' = @{thms o_apply K_record_comp} @ thms; in EVERY split_frees_tacs THEN full_simp_tac (put_simpset (get_simpset thy) ctxt addsimps thms' addsimprocs simprocs) i end); (* split_tac *) (*Split all records in the goal, which are quantified by !! or ALL.*) fun split_tac ctxt = CSUBGOAL (fn (cgoal, i) => let val goal = Thm.term_of cgoal; val has_rec = exists_Const (fn (s, Type (_, [Type (_, [T, _]), _])) => (s = \<^const_name>\Pure.all\ orelse s = \<^const_name>\All\) andalso is_recT T | _ => false); fun is_all (Const (\<^const_name>\Pure.all\, _) $ _) = ~1 | is_all (Const (\<^const_name>\All\, _) $ _) = ~1 | is_all _ = 0; in if has_rec goal then full_simp_tac (put_simpset HOL_basic_ss ctxt addsimprocs [split_simproc is_all]) i else no_tac end); (* wrapper *) val split_name = "record_split_tac"; val split_wrapper = (split_name, fn ctxt => fn tac => split_tac ctxt ORELSE' tac); (** theory extender interface **) (* attributes *) val case_names_fields = Rule_Cases.case_names ["fields"]; fun induct_type_global name = [case_names_fields, Induct.induct_type name]; fun cases_type_global name = [case_names_fields, Induct.cases_type name]; (* tactics *) (*Do case analysis / induction according to rule on last parameter of ith subgoal (or on s if there are no parameters). Instatiation of record variable (and predicate) in rule is calculated to avoid problems with higher order unification.*) fun try_param_tac ctxt s rule = CSUBGOAL (fn (cgoal, i) => let val g = Thm.term_of cgoal; val params = Logic.strip_params g; val concl = HOLogic.dest_Trueprop (Logic.strip_assums_concl g); val rule' = Thm.lift_rule cgoal rule; val (P, ys) = strip_comb (HOLogic.dest_Trueprop (Logic.strip_assums_concl (Thm.prop_of rule'))); (*ca indicates if rule is a case analysis or induction rule*) val (x, ca) = (case rev (drop (length params) ys) of [] => (head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (hd (rev (Logic.strip_assums_hyp (hd (Thm.prems_of rule')))))))), true) | [x] => (head_of x, false)); val rule'' = infer_instantiate ctxt (map (apsnd (Thm.cterm_of ctxt)) (case rev params of [] => (case AList.lookup (op =) (Term.add_frees g []) s of NONE => error "try_param_tac: no such variable" | SOME T => [(#1 (dest_Var P), if ca then concl else lambda (Free (s, T)) concl), (#1 (dest_Var x), Free (s, T))]) | (_, T) :: _ => [(#1 (dest_Var P), fold_rev Term.abs params (if ca then concl else incr_boundvars 1 (Abs (s, T, concl)))), (#1 (dest_Var x), fold_rev Term.abs params (Bound 0))])) rule'; in compose_tac ctxt (false, rule'', Thm.nprems_of rule) i end); fun extension_definition overloaded name fields alphas zeta moreT more vars thy = let val ctxt = Proof_Context.init_global thy; val base_name = Long_Name.base_name name; val fieldTs = map snd fields; val fields_moreTs = fieldTs @ [moreT]; val alphas_zeta = alphas @ [zeta]; val ext_binding = Binding.name (suffix extN base_name); val ext_name = suffix extN name; val ext_tyco = suffix ext_typeN name; val extT = Type (ext_tyco, map TFree alphas_zeta); val ext_type = fields_moreTs ---> extT; (* the tree of new types that will back the record extension *) val mktreeV = Balanced_Tree.make Iso_Tuple_Support.mk_cons_tuple; fun mk_iso_tuple (left, right) (thy, i) = let val suff = if i = 0 then ext_typeN else inner_typeN ^ string_of_int i; val ((_, cons), thy') = thy |> Iso_Tuple_Support.add_iso_tuple_type overloaded (Binding.suffix_name suff (Binding.name base_name), alphas_zeta) (fastype_of left, fastype_of right); in (cons $ left $ right, (thy', i + 1)) end; (*trying to create a 1-element iso_tuple will fail, and is pointless anyway*) fun mk_even_iso_tuple [arg] = pair arg | mk_even_iso_tuple args = mk_iso_tuple (Iso_Tuple_Support.dest_cons_tuple (mktreeV args)); fun build_meta_tree_type i thy vars more = let val len = length vars in if len < 1 then raise TYPE ("meta_tree_type args too short", [], vars) else if len > 16 then let fun group16 [] = [] | group16 xs = take 16 xs :: group16 (drop 16 xs); val vars' = group16 vars; val (composites, (thy', i')) = fold_map mk_even_iso_tuple vars' (thy, i); in build_meta_tree_type i' thy' composites more end else let val (term, (thy', _)) = mk_iso_tuple (mktreeV vars, more) (thy, 0) in (term, thy') end end; val _ = timing_msg ctxt "record extension preparing definitions"; (* 1st stage part 1: introduce the tree of new types *) val (ext_body, typ_thy) = timeit_msg ctxt "record extension nested type def:" (fn () => build_meta_tree_type 1 thy vars more); (* prepare declarations and definitions *) (* 1st stage part 2: define the ext constant *) fun mk_ext args = list_comb (Const (ext_name, ext_type), args); val ext_spec = Logic.mk_equals (mk_ext (vars @ [more]), ext_body); val ([ext_def], defs_thy) = timeit_msg ctxt "record extension constructor def:" (fn () => typ_thy |> Sign.declare_const_global ((ext_binding, ext_type), NoSyn) |> snd |> Global_Theory.add_defs false [((Thm.def_binding ext_binding, ext_spec), [])]); val defs_ctxt = Proof_Context.init_global defs_thy; (* prepare propositions *) val _ = timing_msg ctxt "record extension preparing propositions"; val vars_more = vars @ [more]; val variants = map (fn Free (x, _) => x) vars_more; val ext = mk_ext vars_more; val s = Free (rN, extT); val P = Free (singleton (Name.variant_list variants) "P", extT --> HOLogic.boolT); val inject_prop = (* FIXME local x x' *) let val vars_more' = map (fn (Free (x, T)) => Free (x ^ "'", T)) vars_more in HOLogic.mk_conj (HOLogic.eq_const extT $ mk_ext vars_more $ mk_ext vars_more', \<^term>\True\) === foldr1 HOLogic.mk_conj (map HOLogic.mk_eq (vars_more ~~ vars_more') @ [\<^term>\True\]) end; val induct_prop = (fold_rev Logic.all vars_more (Trueprop (P $ ext)), Trueprop (P $ s)); val split_meta_prop = (* FIXME local P *) let val P = Free (singleton (Name.variant_list variants) "P", extT --> propT) in Logic.mk_equals (Logic.all s (P $ s), fold_rev Logic.all vars_more (P $ ext)) end; val inject = timeit_msg ctxt "record extension inject proof:" (fn () => simplify (put_simpset HOL_ss defs_ctxt) (Goal.prove_sorry_global defs_thy [] [] inject_prop (fn {context = ctxt', ...} => simp_tac (put_simpset HOL_basic_ss ctxt' addsimps [ext_def]) 1 THEN REPEAT_DETERM (resolve_tac ctxt' @{thms refl_conj_eq} 1 ORELSE Iso_Tuple_Support.iso_tuple_intros_tac ctxt' 1 ORELSE resolve_tac ctxt' [refl] 1)))); (*We need a surjection property r = (| f = f r, g = g r ... |) to prove other theorems. We haven't given names to the accessors f, g etc yet however, so we generate an ext structure with free variables as all arguments and allow the introduction tactic to operate on it as far as it can. We then use Drule.export_without_context to convert the free variables into unifiable variables and unify them with (roughly) the definition of the accessor.*) val surject = timeit_msg ctxt "record extension surjective proof:" (fn () => let val start = infer_instantiate defs_ctxt [(("y", 0), Thm.cterm_of defs_ctxt ext)] surject_assist_idE; val tactic1 = simp_tac (put_simpset HOL_basic_ss defs_ctxt addsimps [ext_def]) 1 THEN REPEAT_ALL_NEW (Iso_Tuple_Support.iso_tuple_intros_tac defs_ctxt) 1; val tactic2 = REPEAT (resolve_tac defs_ctxt [surject_assistI] 1 THEN resolve_tac defs_ctxt [refl] 1); val [halfway] = Seq.list_of (tactic1 start); val [surject] = Seq.list_of (tactic2 (Drule.export_without_context halfway)); in surject end); val split_meta = timeit_msg ctxt "record extension split_meta proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] split_meta_prop (fn {context = ctxt', ...} => EVERY1 [resolve_tac ctxt' @{thms equal_intr_rule}, Goal.norm_hhf_tac ctxt', eresolve_tac ctxt' @{thms meta_allE}, assume_tac ctxt', resolve_tac ctxt' [@{thm prop_subst} OF [surject]], REPEAT o eresolve_tac ctxt' @{thms meta_allE}, assume_tac ctxt'])); val induct = timeit_msg ctxt "record extension induct proof:" (fn () => let val (assm, concl) = induct_prop in Goal.prove_sorry_global defs_thy [] [assm] concl (fn {context = ctxt', prems, ...} => cut_tac (split_meta RS Drule.equal_elim_rule2) 1 THEN resolve_tac ctxt' prems 2 THEN asm_simp_tac (put_simpset HOL_ss ctxt') 1) end); val ([(_, [induct']), (_, [inject']), (_, [surjective']), (_, [split_meta'])], thm_thy) = defs_thy |> Global_Theory.note_thmss "" [((Binding.name "ext_induct", []), [([induct], [])]), ((Binding.name "ext_inject", []), [([inject], [])]), ((Binding.name "ext_surjective", []), [([surject], [])]), ((Binding.name "ext_split", []), [([split_meta], [])])]; in (((ext_name, ext_type), (ext_tyco, alphas_zeta), extT, induct', inject', surjective', split_meta', ext_def), thm_thy) end; fun chunks [] [] = [] | chunks [] xs = [xs] | chunks (l :: ls) xs = take l xs :: chunks ls (drop l xs); fun chop_last [] = error "chop_last: list should not be empty" | chop_last [x] = ([], x) | chop_last (x :: xs) = let val (tl, l) = chop_last xs in (x :: tl, l) end; fun subst_last _ [] = error "subst_last: list should not be empty" | subst_last s [_] = [s] | subst_last s (x :: xs) = x :: subst_last s xs; (* mk_recordT *) (*build up the record type from the current extension tpye extT and a list of parent extensions, starting with the root of the record hierarchy*) fun mk_recordT extT = fold_rev (fn (parent, Ts) => fn T => Type (parent, subst_last T Ts)) extT; (* code generation *) fun mk_random_eq tyco vs extN Ts = let (* FIXME local i etc. *) val size = \<^term>\i::natural\; fun termifyT T = HOLogic.mk_prodT (T, \<^typ>\unit \ term\); val T = Type (tyco, map TFree vs); val Tm = termifyT T; val params = Name.invent_names Name.context "x" Ts; val lhs = HOLogic.mk_random T size; val tc = HOLogic.mk_return Tm \<^typ>\Random.seed\ (HOLogic.mk_valtermify_app extN params T); val rhs = HOLogic.mk_ST (map (fn (v, T') => ((HOLogic.mk_random T' size, \<^typ>\Random.seed\), SOME (v, termifyT T'))) params) tc \<^typ>\Random.seed\ (SOME Tm, \<^typ>\Random.seed\); in (lhs, rhs) end fun mk_full_exhaustive_eq tyco vs extN Ts = let (* FIXME local i etc. *) val size = \<^term>\i::natural\; fun termifyT T = HOLogic.mk_prodT (T, \<^typ>\unit \ term\); val T = Type (tyco, map TFree vs); val test_function = Free ("f", termifyT T --> \<^typ>\(bool \ term list) option\); val params = Name.invent_names Name.context "x" Ts; fun mk_full_exhaustive U = \<^Const>\full_exhaustive_class.full_exhaustive U\; val lhs = mk_full_exhaustive T $ test_function $ size; val tc = test_function $ (HOLogic.mk_valtermify_app extN params T); val rhs = fold_rev (fn (v, U) => fn cont => mk_full_exhaustive U $ (lambda (Free (v, termifyT U)) cont) $ size) params tc; in (lhs, rhs) end; fun instantiate_sort_record (sort, mk_eq) tyco vs extN Ts thy = let val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (mk_eq tyco vs extN Ts)); in thy |> Class.instantiation ([tyco], vs, sort) |> `(fn lthy => Syntax.check_term lthy eq) |-> (fn eq => Specification.definition NONE [] [] ((Binding.concealed Binding.empty, []), eq)) |> snd |> Class.prove_instantiation_exit (fn ctxt => Class.intro_classes_tac ctxt []) end; fun ensure_sort_record (sort, mk_eq) ext_tyco vs extN Ts thy = let val algebra = Sign.classes_of thy; val has_inst = Sorts.has_instance algebra ext_tyco sort; in if has_inst then thy else (case Quickcheck_Common.perhaps_constrain thy (map (rpair sort) Ts) vs of SOME constrain => instantiate_sort_record (sort, mk_eq) ext_tyco (map constrain vs) extN ((map o map_atyps) (fn TFree v => TFree (constrain v)) Ts) thy | NONE => thy) end; fun add_code ext_tyco vs extT ext simps inject thy = if Config.get_global thy codegen then let val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (\<^Const>\HOL.equal extT\, \<^Const>\HOL.eq extT\)); fun tac ctxt eq_def = Class.intro_classes_tac ctxt [] THEN rewrite_goals_tac ctxt [Simpdata.mk_eq eq_def] THEN ALLGOALS (resolve_tac ctxt @{thms refl}); fun mk_eq ctxt eq_def = rewrite_rule ctxt [Axclass.unoverload ctxt (Thm.symmetric (Simpdata.mk_eq eq_def))] inject; fun mk_eq_refl ctxt = \<^instantiate>\'a = \Thm.ctyp_of ctxt (Logic.varifyT_global extT)\ in lemma (schematic) \equal_class.equal x x \ True\ by (rule equal_refl)\ |> Axclass.unoverload ctxt; val ensure_random_record = ensure_sort_record (\<^sort>\random\, mk_random_eq); val ensure_exhaustive_record = ensure_sort_record (\<^sort>\full_exhaustive\, mk_full_exhaustive_eq); fun add_code eq_def thy = let val ctxt = Proof_Context.init_global thy; in thy |> Code.declare_default_eqns_global [(mk_eq ctxt eq_def, true), (mk_eq_refl ctxt, false)] end; in thy |> Code.declare_datatype_global [ext] |> Code.declare_default_eqns_global (map (rpair true) simps) |> Class.instantiation ([ext_tyco], vs, [HOLogic.class_equal]) |> `(fn lthy => Syntax.check_term lthy eq) |-> (fn eq => Specification.definition NONE [] [] (Binding.empty_atts, eq)) |-> (fn (_, (_, eq_def)) => Class.prove_instantiation_exit_result Morphism.thm tac eq_def) |-> add_code |> ensure_random_record ext_tyco vs (fst ext) (binder_types (snd ext)) |> ensure_exhaustive_record ext_tyco vs (fst ext) (binder_types (snd ext)) end else thy; fun add_ctr_sugar ctr exhaust inject sel_thms = Ctr_Sugar.default_register_ctr_sugar_global (K true) {kind = Ctr_Sugar.Record, T = body_type (fastype_of ctr), ctrs = [ctr], casex = Term.dummy, discs = [], selss = [], exhaust = exhaust, nchotomy = Drule.dummy_thm, injects = [inject], distincts = [], case_thms = [], case_cong = Drule.dummy_thm, case_cong_weak = Drule.dummy_thm, case_distribs = [], split = Drule.dummy_thm, split_asm = Drule.dummy_thm, disc_defs = [], disc_thmss = [], discIs = [], disc_eq_cases = [], sel_defs = [], sel_thmss = [sel_thms], distinct_discsss = [], exhaust_discs = [], exhaust_sels = [], collapses = [], expands = [], split_sels = [], split_sel_asms = [], case_eq_ifs = []}; fun lhs_of_equation \<^Const_>\Pure.eq _ for t _\ = t | lhs_of_equation \<^Const_>\Trueprop for \<^Const_>\HOL.eq _ for t _\\ = t; fun add_spec_rule rule = let val head = head_of (lhs_of_equation (Thm.prop_of rule)) in Spec_Rules.add_global Binding.empty Spec_Rules.equational [head] [rule] end; (* definition *) fun definition overloaded (alphas, binding) parent (parents: parent_info list) raw_fields thy0 = let val ctxt0 = Proof_Context.init_global thy0; val prefix = Binding.name_of binding; val name = Sign.full_name thy0 binding; val full = Sign.full_name_path thy0 prefix; val bfields = map (fn (x, T, _) => (x, T)) raw_fields; val field_syntax = map #3 raw_fields; val parent_fields = maps #fields parents; val parent_chunks = map (length o #fields) parents; val parent_names = map fst parent_fields; val parent_types = map snd parent_fields; val parent_fields_len = length parent_fields; val parent_variants = Name.variant_list [moreN, rN, rN ^ "'", wN] (map Long_Name.base_name parent_names); val parent_vars = map2 (curry Free) parent_variants parent_types; val parent_len = length parents; val fields = map (apfst full) bfields; val names = map fst fields; val types = map snd fields; val alphas_fields = fold Term.add_tfreesT types []; val alphas_ext = inter (op =) alphas_fields alphas; val len = length fields; val variants = Name.variant_list (moreN :: rN :: (rN ^ "'") :: wN :: parent_variants) (map (Binding.name_of o fst) bfields); val vars = map2 (curry Free) variants types; val named_vars = names ~~ vars; val idxms = 0 upto len; val all_fields = parent_fields @ fields; val all_types = parent_types @ types; val all_variants = parent_variants @ variants; val all_vars = parent_vars @ vars; val all_named_vars = (parent_names ~~ parent_vars) @ named_vars; val zeta = (singleton (Name.variant_list (map #1 alphas)) "'z", \<^sort>\type\); val moreT = TFree zeta; val more = Free (moreN, moreT); val full_moreN = full (Binding.name moreN); val bfields_more = bfields @ [(Binding.name moreN, moreT)]; val fields_more = fields @ [(full_moreN, moreT)]; val named_vars_more = named_vars @ [(full_moreN, more)]; val all_vars_more = all_vars @ [more]; val all_named_vars_more = all_named_vars @ [(full_moreN, more)]; (* 1st stage: ext_thy *) val extension_name = full binding; val ((ext, (ext_tyco, vs), extT, ext_induct, ext_inject, ext_surjective, ext_split, ext_def), ext_thy) = thy0 |> Sign.qualified_path false binding |> extension_definition overloaded extension_name fields alphas_ext zeta moreT more vars; val ext_ctxt = Proof_Context.init_global ext_thy; val _ = timing_msg ext_ctxt "record preparing definitions"; val Type extension_scheme = extT; val extension_name = unsuffix ext_typeN (fst extension_scheme); val extension = let val (n, Ts) = extension_scheme in (n, subst_last HOLogic.unitT Ts) end; val extension_names = map (unsuffix ext_typeN o fst o #extension) parents @ [extension_name]; val extension_id = implode extension_names; fun rec_schemeT n = mk_recordT (map #extension (drop n parents)) extT; val rec_schemeT0 = rec_schemeT 0; fun recT n = let val (c, Ts) = extension in mk_recordT (map #extension (drop n parents)) (Type (c, subst_last HOLogic.unitT Ts)) end; val recT0 = recT 0; fun mk_rec args n = let val (args', more) = chop_last args; fun mk_ext' ((name, T), args) more = mk_ext (name, T) (args @ [more]); fun build Ts = fold_rev mk_ext' (drop n ((extension_names ~~ Ts) ~~ chunks parent_chunks args')) more; in if more = HOLogic.unit then build (map_range recT (parent_len + 1)) else build (map_range rec_schemeT (parent_len + 1)) end; val r_rec0 = mk_rec all_vars_more 0; val r_rec_unit0 = mk_rec (all_vars @ [HOLogic.unit]) 0; fun r n = Free (rN, rec_schemeT n); val r0 = r 0; fun r_unit n = Free (rN, recT n); val r_unit0 = r_unit 0; (* print translations *) val record_ext_type_abbr_tr's = let val trname = hd extension_names; val last_ext = unsuffix ext_typeN (fst extension); in [record_ext_type_abbr_tr' name alphas zeta last_ext rec_schemeT0 trname] end; val record_ext_type_tr's = let (*avoid conflict with record_type_abbr_tr's*) val trnames = if parent_len > 0 then [extension_name] else []; in map record_ext_type_tr' trnames end; val print_translation = map field_update_tr' (full_moreN :: names) @ [record_ext_tr' extension_name] @ record_ext_type_tr's @ record_ext_type_abbr_tr's; (* prepare declarations *) val sel_decls = map (mk_selC rec_schemeT0 o apfst Binding.name_of) bfields_more; val upd_decls = map (mk_updC updateN rec_schemeT0 o apfst Binding.name_of) bfields_more; val make_decl = (makeN, all_types ---> recT0); val fields_decl = (fields_selN, types ---> Type extension); val extend_decl = (extendN, recT0 --> moreT --> rec_schemeT0); val truncate_decl = (truncateN, rec_schemeT0 --> recT0); (* prepare definitions *) val ext_defs = ext_def :: map #ext_def parents; (*Theorems from the iso_tuple intros. By unfolding ext_defs from r_rec0 we create a tree of constructor calls (many of them Pair, but others as well). The introduction rules for update_accessor_eq_assist can unify two different ways on these constructors. If we take the complete result sequence of running a the introduction tactic, we get one theorem for each upd/acc pair, from which we can derive the bodies of our selector and updator and their convs.*) val (accessor_thms, updator_thms, upd_acc_cong_assists) = timeit_msg ext_ctxt "record getting tree access/updates:" (fn () => let val r_rec0_Vars = let (*pick variable indices of 1 to avoid possible variable collisions with existing variables in updacc_eq_triv*) fun to_Var (Free (c, T)) = Var ((c, 1), T); in mk_rec (map to_Var all_vars_more) 0 end; val init_thm = infer_instantiate ext_ctxt [(("v", 0), Thm.cterm_of ext_ctxt r_rec0), (("v'", 0), Thm.cterm_of ext_ctxt r_rec0_Vars)] updacc_eq_triv; val terminal = resolve_tac ext_ctxt [updacc_eq_idI] 1 THEN resolve_tac ext_ctxt [refl] 1; val tactic = simp_tac (put_simpset HOL_basic_ss ext_ctxt addsimps ext_defs) 1 THEN REPEAT (Iso_Tuple_Support.iso_tuple_intros_tac ext_ctxt 1 ORELSE terminal); val updaccs = Seq.list_of (tactic init_thm); in (updaccs RL [updacc_accessor_eqE], updaccs RL [updacc_updator_eqE], updaccs RL [updacc_cong_from_eq]) end); fun lastN xs = drop parent_fields_len xs; (*selectors*) fun mk_sel_spec ((c, T), thm) = let val (acc $ arg, _) = HOLogic.dest_eq (HOLogic.dest_Trueprop (Envir.beta_eta_contract (Thm.concl_of thm))); val _ = if arg aconv r_rec0 then () else raise TERM ("mk_sel_spec: different arg", [arg]); in Const (mk_selC rec_schemeT0 (c, T)) :== acc end; val sel_specs = map mk_sel_spec (fields_more ~~ lastN accessor_thms); (*updates*) fun mk_upd_spec ((c, T), thm) = let val (upd $ _ $ arg, _) = HOLogic.dest_eq (HOLogic.dest_Trueprop (Envir.beta_eta_contract (Thm.concl_of thm))); val _ = if arg aconv r_rec0 then () else raise TERM ("mk_sel_spec: different arg", [arg]); in Const (mk_updC updateN rec_schemeT0 (c, T)) :== upd end; val upd_specs = map mk_upd_spec (fields_more ~~ lastN updator_thms); (*derived operations*) val make_spec = list_comb (Const (full (Binding.name makeN), all_types ---> recT0), all_vars) :== mk_rec (all_vars @ [HOLogic.unit]) 0; val fields_spec = list_comb (Const (full (Binding.name fields_selN), types ---> Type extension), vars) :== mk_rec (all_vars @ [HOLogic.unit]) parent_len; val extend_spec = Const (full (Binding.name extendN), recT0 --> moreT --> rec_schemeT0) $ r_unit0 $ more :== mk_rec ((map (mk_sel r_unit0) all_fields) @ [more]) 0; val truncate_spec = Const (full (Binding.name truncateN), rec_schemeT0 --> recT0) $ r0 :== mk_rec ((map (mk_sel r0) all_fields) @ [HOLogic.unit]) 0; (* 2st stage: defs_thy *) val (((sel_defs, upd_defs), derived_defs), defs_thy) = timeit_msg ext_ctxt "record trfuns/tyabbrs/selectors/updates/make/fields/extend/truncate defs:" (fn () => ext_thy |> Sign.print_translation print_translation |> Sign.restore_naming thy0 |> Typedecl.abbrev_global (binding, map #1 alphas, NoSyn) recT0 |> snd |> Typedecl.abbrev_global (Binding.suffix_name schemeN binding, map #1 (alphas @ [zeta]), NoSyn) rec_schemeT0 |> snd |> Sign.qualified_path false binding |> fold (fn ((x, T), mx) => snd o Sign.declare_const_global ((Binding.name x, T), mx)) (sel_decls ~~ (field_syntax @ [NoSyn])) |> fold (fn (x, T) => snd o Sign.declare_const_global ((Binding.name x, T), NoSyn)) (upd_decls @ [make_decl, fields_decl, extend_decl, truncate_decl]) |> (Global_Theory.add_defs false o map (Thm.no_attributes o apfst (Binding.concealed o Binding.name))) sel_specs ||>> (Global_Theory.add_defs false o map (Thm.no_attributes o apfst (Binding.concealed o Binding.name))) upd_specs ||>> (Global_Theory.add_defs false o map (Thm.no_attributes o apfst (Binding.concealed o Binding.name))) [make_spec, fields_spec, extend_spec, truncate_spec]); val defs_ctxt = Proof_Context.init_global defs_thy; (* prepare propositions *) val _ = timing_msg defs_ctxt "record preparing propositions"; val P = Free (singleton (Name.variant_list all_variants) "P", rec_schemeT0 --> HOLogic.boolT); val C = Free (singleton (Name.variant_list all_variants) "C", HOLogic.boolT); val P_unit = Free (singleton (Name.variant_list all_variants) "P", recT0 --> HOLogic.boolT); (*selectors*) val sel_conv_props = map (fn (c, x as Free (_, T)) => mk_sel r_rec0 (c, T) === x) named_vars_more; (*updates*) fun mk_upd_prop i (c, T) = let val x' = Free (singleton (Name.variant_list all_variants) (Long_Name.base_name c ^ "'"), T --> T); val n = parent_fields_len + i; val args' = nth_map n (K (x' $ nth all_vars_more n)) all_vars_more; in mk_upd updateN c x' r_rec0 === mk_rec args' 0 end; val upd_conv_props = map2 mk_upd_prop idxms fields_more; (*induct*) val induct_scheme_prop = fold_rev Logic.all all_vars_more (Trueprop (P $ r_rec0)) ==> Trueprop (P $ r0); val induct_prop = (fold_rev Logic.all all_vars (Trueprop (P_unit $ r_rec_unit0)), Trueprop (P_unit $ r_unit0)); (*surjective*) val surjective_prop = let val args = map (fn (c, Free (_, T)) => mk_sel r0 (c, T)) all_named_vars_more in r0 === mk_rec args 0 end; (*cases*) val cases_scheme_prop = (fold_rev Logic.all all_vars_more ((r0 === r_rec0) ==> Trueprop C), Trueprop C); val cases_prop = fold_rev Logic.all all_vars ((r_unit0 === r_rec_unit0) ==> Trueprop C) ==> Trueprop C; (*split*) val split_meta_prop = let val P = Free (singleton (Name.variant_list all_variants) "P", rec_schemeT0 --> propT); in Logic.mk_equals (Logic.all r0 (P $ r0), fold_rev Logic.all all_vars_more (P $ r_rec0)) end; val split_object_prop = let val ALL = fold_rev (fn (v, T) => fn t => HOLogic.mk_all (v, T, t)) in ALL [dest_Free r0] (P $ r0) === ALL (map dest_Free all_vars_more) (P $ r_rec0) end; val split_ex_prop = let val EX = fold_rev (fn (v, T) => fn t => HOLogic.mk_exists (v, T, t)) in EX [dest_Free r0] (P $ r0) === EX (map dest_Free all_vars_more) (P $ r_rec0) end; (*equality*) val equality_prop = let val s' = Free (rN ^ "'", rec_schemeT0); fun mk_sel_eq (c, Free (_, T)) = mk_sel r0 (c, T) === mk_sel s' (c, T); val seleqs = map mk_sel_eq all_named_vars_more; in Logic.all r0 (Logic.all s' (Logic.list_implies (seleqs, r0 === s'))) end; (* 3rd stage: thms_thy *) val record_ss = get_simpset defs_thy; val sel_upd_ss = simpset_of (put_simpset record_ss defs_ctxt addsimps (sel_defs @ accessor_thms @ upd_defs @ updator_thms)); val (sel_convs, upd_convs) = timeit_msg defs_ctxt "record sel_convs/upd_convs proof:" (fn () => grouped 10 Par_List.map (fn prop => Goal.prove_sorry_global defs_thy [] [] prop (fn {context = ctxt', ...} => ALLGOALS (asm_full_simp_tac (put_simpset sel_upd_ss ctxt')))) (sel_conv_props @ upd_conv_props)) |> chop (length sel_conv_props); val (fold_congs, unfold_congs) = timeit_msg defs_ctxt "record upd fold/unfold congs:" (fn () => let val symdefs = map Thm.symmetric (sel_defs @ upd_defs); val fold_ctxt = put_simpset HOL_basic_ss defs_ctxt addsimps symdefs; val ua_congs = map (Drule.export_without_context o simplify fold_ctxt) upd_acc_cong_assists; in (ua_congs RL [updacc_foldE], ua_congs RL [updacc_unfoldE]) end); val parent_induct = Option.map #induct_scheme (try List.last parents); val induct_scheme = timeit_msg defs_ctxt "record induct_scheme proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] induct_scheme_prop (fn {context = ctxt', ...} => EVERY [case parent_induct of NONE => all_tac | SOME ind => try_param_tac ctxt' rN ind 1, try_param_tac ctxt' rN ext_induct 1, asm_simp_tac (put_simpset HOL_basic_ss ctxt') 1])); val induct = timeit_msg defs_ctxt "record induct proof:" (fn () => Goal.prove_sorry_global defs_thy [] [#1 induct_prop] (#2 induct_prop) (fn {context = ctxt', prems, ...} => try_param_tac ctxt' rN induct_scheme 1 THEN try_param_tac ctxt' "more" @{thm unit.induct} 1 THEN resolve_tac ctxt' prems 1)); val surjective = timeit_msg defs_ctxt "record surjective proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] surjective_prop (fn {context = ctxt', ...} => EVERY [resolve_tac ctxt' [surject_assist_idE] 1, simp_tac (put_simpset HOL_basic_ss ctxt' addsimps ext_defs) 1, REPEAT (Iso_Tuple_Support.iso_tuple_intros_tac ctxt' 1 ORELSE (resolve_tac ctxt' [surject_assistI] 1 THEN simp_tac (put_simpset (get_sel_upd_defs defs_thy) ctxt' addsimps (sel_defs @ @{thms o_assoc id_apply id_o o_id})) 1))])); val cases_scheme = timeit_msg defs_ctxt "record cases_scheme proof:" (fn () => Goal.prove_sorry_global defs_thy [] [#1 cases_scheme_prop] (#2 cases_scheme_prop) (fn {context = ctxt', prems, ...} => resolve_tac ctxt' prems 1 THEN resolve_tac ctxt' [surjective] 1)); val cases = timeit_msg defs_ctxt "record cases proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] cases_prop (fn {context = ctxt', ...} => try_param_tac ctxt' rN cases_scheme 1 THEN ALLGOALS (asm_full_simp_tac (put_simpset HOL_basic_ss ctxt' addsimps @{thms unit_all_eq1})))); val split_meta = timeit_msg defs_ctxt "record split_meta proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] split_meta_prop (fn {context = ctxt', ...} => EVERY1 [resolve_tac ctxt' @{thms equal_intr_rule}, Goal.norm_hhf_tac ctxt', eresolve_tac ctxt' @{thms meta_allE}, assume_tac ctxt', resolve_tac ctxt' [@{thm prop_subst} OF [surjective]], REPEAT o eresolve_tac ctxt' @{thms meta_allE}, assume_tac ctxt'])); val split_object = timeit_msg defs_ctxt "record split_object proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] split_object_prop (fn {context = ctxt', ...} => resolve_tac ctxt' [@{lemma "Trueprop A \ Trueprop B \ A = B" by (rule iffI) unfold}] 1 THEN rewrite_goals_tac ctxt' @{thms atomize_all [symmetric]} THEN resolve_tac ctxt' [split_meta] 1)); val split_ex = timeit_msg defs_ctxt "record split_ex proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] split_ex_prop (fn {context = ctxt', ...} => simp_tac (put_simpset HOL_basic_ss ctxt' addsimps (@{lemma "\x. P x \ \ (\x. \ P x)" by simp} :: @{thms not_not Not_eq_iff})) 1 THEN resolve_tac ctxt' [split_object] 1)); val equality = timeit_msg defs_ctxt "record equality proof:" (fn () => Goal.prove_sorry_global defs_thy [] [] equality_prop (fn {context = ctxt', ...} => asm_full_simp_tac (put_simpset record_ss ctxt' addsimps (split_meta :: sel_convs)) 1)); val ([(_, sel_convs'), (_, upd_convs'), (_, sel_defs'), (_, upd_defs'), (_, fold_congs'), (_, unfold_congs'), (_, splits' as [split_meta', split_object', split_ex']), (_, derived_defs'), (_, [surjective']), (_, [equality']), (_, [induct_scheme']), (_, [induct']), (_, [cases_scheme']), (_, [cases'])], thms_thy) = defs_thy |> Code.declare_default_eqns_global (map (rpair true) derived_defs) |> Global_Theory.note_thmss "" [((Binding.name "select_convs", []), [(sel_convs, [])]), ((Binding.name "update_convs", []), [(upd_convs, [])]), ((Binding.name "select_defs", []), [(sel_defs, [])]), ((Binding.name "update_defs", []), [(upd_defs, [])]), ((Binding.name "fold_congs", []), [(fold_congs, [])]), ((Binding.name "unfold_congs", []), [(unfold_congs, [])]), ((Binding.name "splits", []), [([split_meta, split_object, split_ex], [])]), ((Binding.name "defs", []), [(derived_defs, [])]), ((Binding.name "surjective", []), [([surjective], [])]), ((Binding.name "equality", []), [([equality], [])]), ((Binding.name "induct_scheme", induct_type_global (suffix schemeN name)), [([induct_scheme], [])]), ((Binding.name "induct", induct_type_global name), [([induct], [])]), ((Binding.name "cases_scheme", cases_type_global (suffix schemeN name)), [([cases_scheme], [])]), ((Binding.name "cases", cases_type_global name), [([cases], [])])]; val sel_upd_simps = sel_convs' @ upd_convs'; val sel_upd_defs = sel_defs' @ upd_defs'; val depth = parent_len + 1; val ([(_, simps'), (_, iffs')], thms_thy') = thms_thy |> Global_Theory.note_thmss "" [((Binding.name "simps", [Simplifier.simp_add]), [(sel_upd_simps, [])]), ((Binding.name "iffs", [iff_add]), [([ext_inject], [])])]; val info = make_info alphas parent fields extension ext_induct ext_inject ext_surjective ext_split ext_def sel_convs' upd_convs' sel_defs' upd_defs' fold_congs' unfold_congs' splits' derived_defs' surjective' equality' induct_scheme' induct' cases_scheme' cases' simps' iffs'; val final_thy = thms_thy' |> put_record name info |> put_sel_upd names full_moreN depth sel_upd_simps sel_upd_defs |> add_equalities extension_id equality' |> add_extinjects ext_inject |> add_extsplit extension_name ext_split |> add_splits extension_id (split_meta', split_object', split_ex', induct_scheme') |> add_extfields extension_name (fields @ [(full_moreN, moreT)]) |> add_fieldext (extension_name, snd extension) names |> add_code ext_tyco vs extT ext simps' ext_inject |> add_ctr_sugar (Const ext) cases_scheme' ext_inject sel_convs' |> fold add_spec_rule (sel_convs' @ upd_convs' @ derived_defs') |> Sign.restore_naming thy0; in final_thy end; (* add_record *) local fun read_parent NONE ctxt = (NONE, ctxt) | read_parent (SOME raw_T) ctxt = (case Proof_Context.read_typ_abbrev ctxt raw_T of Type (name, Ts) => (SOME (Ts, name), fold Variable.declare_typ Ts ctxt) | T => error ("Bad parent record specification: " ^ Syntax.string_of_typ ctxt T)); fun read_fields raw_fields ctxt = let val Ts = Syntax.read_typs ctxt (map (fn (_, raw_T, _) => raw_T) raw_fields); val fields = map2 (fn (x, _, mx) => fn T => (x, T, mx)) raw_fields Ts; val ctxt' = fold Variable.declare_typ Ts ctxt; in (fields, ctxt') end; in fun add_record overloaded (params, binding) raw_parent raw_fields thy = let val ctxt = Proof_Context.init_global thy; fun cert_typ T = Type.no_tvars (Proof_Context.cert_typ ctxt T) handle TYPE (msg, _, _) => error msg; (* specification *) val parent = Option.map (apfst (map cert_typ)) raw_parent handle ERROR msg => cat_error msg ("The error(s) above occurred in parent record specification"); val parent_args = (case parent of SOME (Ts, _) => Ts | NONE => []); val parents = get_parent_info thy parent; val bfields = raw_fields |> map (fn (x, raw_T, mx) => (x, cert_typ raw_T, mx) handle ERROR msg => cat_error msg ("The error(s) above occurred in record field " ^ Binding.print x)); (* errors *) val name = Sign.full_name thy binding; val err_dup_record = if is_none (get_info thy name) then [] else ["Duplicate definition of record " ^ quote name]; val spec_frees = fold Term.add_tfreesT (parent_args @ map #2 bfields) []; val err_extra_frees = (case subtract (op =) params spec_frees of [] => [] | extras => ["Extra free type variable(s) " ^ commas (map (Syntax.string_of_typ ctxt o TFree) extras)]); val err_no_fields = if null bfields then ["No fields present"] else []; val err_dup_fields = (case duplicates Binding.eq_name (map #1 bfields) of [] => [] | dups => ["Duplicate field(s) " ^ commas (map Binding.print dups)]); val err_bad_fields = if forall (not_equal moreN o Binding.name_of o #1) bfields then [] else ["Illegal field name " ^ quote moreN]; val errs = err_dup_record @ err_extra_frees @ err_no_fields @ err_dup_fields @ err_bad_fields; val _ = if null errs then () else error (cat_lines errs); in thy |> definition overloaded (params, binding) parent parents bfields end handle ERROR msg => cat_error msg ("Failed to define record " ^ Binding.print binding); fun add_record_cmd overloaded (raw_params, binding) raw_parent raw_fields thy = let val ctxt = Proof_Context.init_global thy; val params = map (apsnd (Typedecl.read_constraint ctxt)) raw_params; val ctxt1 = fold (Variable.declare_typ o TFree) params ctxt; val (parent, ctxt2) = read_parent raw_parent ctxt1; val (fields, ctxt3) = read_fields raw_fields ctxt2; val params' = map (Proof_Context.check_tfree ctxt3) params; in thy |> add_record overloaded (params', binding) parent fields end; end; (* printing *) local fun the_parent_recT (Type (parent, [Type (_, [unit as Type (_,[])])])) = Type (parent, [unit]) | the_parent_recT (Type (extT, [T])) = Type (extT, [the_parent_recT T]) | the_parent_recT T = raise TYPE ("Not a unit record scheme with parent: ", [T], []) in fun pretty_recT ctxt typ = let val thy = Proof_Context.theory_of ctxt val (fs, (_, moreT)) = get_recT_fields thy typ val _ = if moreT = HOLogic.unitT then () else raise TYPE ("Not a unit record scheme: ", [typ], []) val parent = if length (dest_recTs typ) >= 2 then SOME (the_parent_recT typ) else NONE val pfs = case parent of SOME p => fst (get_recT_fields thy p) | NONE => [] val fs' = drop (length pfs) fs fun pretty_field (name, typ) = Pretty.block [ Syntax.pretty_term ctxt (Const (name, typ)), Pretty.brk 1, Pretty.str "::", Pretty.brk 1, Syntax.pretty_typ ctxt typ ] in Pretty.block (Library.separate (Pretty.brk 1) ([Pretty.keyword1 "record", Syntax.pretty_typ ctxt typ, Pretty.str "="] @ (case parent of SOME p => [Syntax.pretty_typ ctxt p, Pretty.str "+"] | NONE => [])) @ Pretty.fbrk :: Pretty.fbreaks (map pretty_field fs')) end end fun string_of_record ctxt s = let val T = Syntax.read_typ ctxt s in Pretty.string_of (pretty_recT ctxt T) handle TYPE _ => error ("Unknown record: " ^ Syntax.string_of_typ ctxt T) end val print_record = let fun print_item string_of (modes, arg) = Toplevel.keep (fn state => Print_Mode.with_modes modes (fn () => Output.writeln (string_of state arg)) ()); in print_item (string_of_record o Toplevel.context_of) end (* outer syntax *) val _ = Outer_Syntax.command \<^command_keyword>\record\ "define extensible record" (Parse_Spec.overloaded -- (Parse.type_args_constrained -- Parse.binding) -- (\<^keyword>\=\ |-- Scan.option (Parse.typ --| \<^keyword>\+\) -- Scan.repeat1 Parse.const_binding) >> (fn ((overloaded, x), (y, z)) => Toplevel.theory (add_record_cmd {overloaded = overloaded} x y z))); val opt_modes = Scan.optional (\<^keyword>\(\ |-- Parse.!!! (Scan.repeat1 Parse.name --| \<^keyword>\)\)) [] val _ = Outer_Syntax.command \<^command_keyword>\print_record\ "print record definiton" (opt_modes -- Parse.typ >> print_record); end diff --git a/src/Pure/ex/Def.thy b/src/Pure/ex/Def.thy --- a/src/Pure/ex/Def.thy +++ b/src/Pure/ex/Def.thy @@ -1,106 +1,109 @@ (* Title: Pure/ex/Def.thy Author: Makarius Primitive constant definition, without fact definition; automatic expansion via Simplifier (simproc). *) theory Def imports Pure keywords "def" :: thy_defn begin ML \ signature DEF = sig val get_def: Proof.context -> cterm -> thm option val def: (binding * typ option * mixfix) option -> (binding * typ option * mixfix) list -> term -> local_theory -> term * local_theory val def_cmd: (binding * string option * mixfix) option -> (binding * string option * mixfix) list -> string -> local_theory -> term * local_theory end; structure Def: DEF = struct (* context data *) type def = {lhs: term, eq: thm}; val eq_def : def * def -> bool = op aconv o apply2 #lhs; fun transform_def phi ({lhs, eq}: def) = {lhs = Morphism.term phi lhs, eq = Morphism.thm phi eq}; +fun trim_context_def ({lhs, eq}: def) = + {lhs = lhs, eq = Thm.trim_context eq}; + structure Data = Generic_Data ( type T = def Item_Net.T; val empty : T = Item_Net.init eq_def (single o #lhs); val merge = Item_Net.merge; ); fun declare_def lhs eq lthy = let val def0: def = {lhs = lhs, eq = Thm.trim_context eq} in lthy |> Local_Theory.declaration {syntax = false, pervasive = true, pos = \<^here>} (fn phi => fn context => - let val psi = Morphism.set_trim_context'' context phi - in (Data.map o Item_Net.update) (transform_def psi def0) context end) + let val def' = def0 |> transform_def phi |> trim_context_def + in (Data.map o Item_Net.update) def' context end) end; fun get_def ctxt ct = let val thy = Proof_Context.theory_of ctxt; val data = Data.get (Context.Proof ctxt); val t = Thm.term_of ct; fun match_def {lhs, eq} = if Pattern.matches thy (lhs, t) then let val inst = Thm.match (Thm.cterm_of ctxt lhs, ct) in SOME (Thm.instantiate inst (Thm.transfer thy eq)) end else NONE; in Item_Net.retrieve_matching data t |> get_first match_def end; (* simproc setup *) val _ = (Theory.setup o Named_Target.theory_map) (Simplifier.define_simproc \<^binding>\expand_def\ {lhss = [Free ("x", TFree ("'a", []))], proc = K get_def}); (* Isar command *) fun gen_def prep_spec raw_var raw_params raw_spec lthy = let val ((vars, xs, get_pos, spec), _) = lthy |> prep_spec (the_list raw_var) raw_params [] raw_spec; val (((x, _), rhs), prove) = Local_Defs.derived_def lthy get_pos {conditional = false} spec; val _ = Name.reject_internal (x, []); val (b, mx) = (case (vars, xs) of ([], []) => (Binding.make (x, (case get_pos x of [] => Position.none | p :: _ => p)), NoSyn) | ([(b, _, mx)], [y]) => if x = y then (b, mx) else error ("Head of definition " ^ quote x ^ " differs from declaration " ^ quote y ^ Position.here (Binding.pos_of b))); val ((lhs, (_, eq)), lthy') = lthy |> Local_Theory.define_internal ((b, mx), (Binding.empty_atts, rhs)); (*sanity check for original specification*) val _: thm = prove lthy' eq; in (lhs, declare_def lhs eq lthy') end; val def = gen_def Specification.check_spec_open; val def_cmd = gen_def Specification.read_spec_open; val _ = Outer_Syntax.local_theory \<^command_keyword>\def\ "primitive constant definition, without fact definition" (Scan.option Parse_Spec.constdecl -- Parse.prop -- Parse.for_fixes >> (fn ((decl, spec), params) => #2 o def_cmd decl params spec)); end; \ end diff --git a/src/Pure/morphism.ML b/src/Pure/morphism.ML --- a/src/Pure/morphism.ML +++ b/src/Pure/morphism.ML @@ -1,260 +1,267 @@ (* Title: Pure/morphism.ML Author: Makarius Abstract morphisms on formal entities. *) infix 1 $> signature BASIC_MORPHISM = sig type morphism val $> : morphism * morphism -> morphism end signature MORPHISM = sig include BASIC_MORPHISM exception MORPHISM of string * exn val the_theory: theory option -> theory val set_context: theory -> morphism -> morphism val set_context': Proof.context -> morphism -> morphism val set_context'': Context.generic -> morphism -> morphism val reset_context: morphism -> morphism val morphism: string -> {binding: (theory option -> binding -> binding) list, typ: (theory option -> typ -> typ) list, term: (theory option -> term -> term) list, fact: (theory option -> thm list -> thm list) list} -> morphism val is_identity: morphism -> bool val is_empty: morphism -> bool val pretty: morphism -> Pretty.T val binding: morphism -> binding -> binding val binding_prefix: morphism -> (string * bool) list val typ: morphism -> typ -> typ val term: morphism -> term -> term val fact: morphism -> thm list -> thm list val thm: morphism -> thm -> thm val cterm: morphism -> cterm -> cterm val identity: morphism val default: morphism option -> morphism val compose: morphism -> morphism -> morphism type 'a entity val entity: (morphism -> 'a) -> 'a entity val entity_reset_context: 'a entity -> 'a entity val entity_set_context: theory -> 'a entity -> 'a entity val entity_set_context': Proof.context -> 'a entity -> 'a entity val entity_set_context'': Context.generic -> 'a entity -> 'a entity val transform: morphism -> 'a entity -> 'a entity val transform_reset_context: morphism -> 'a entity -> 'a entity val form: 'a entity -> 'a val form_entity: (morphism -> 'a) -> 'a + val form_context: theory -> (theory -> 'a) entity -> 'a + val form_context': Proof.context -> (Proof.context -> 'a) entity -> 'a + val form_context'': Context.generic -> (Context.generic -> 'a) entity -> 'a type declaration = morphism -> Context.generic -> Context.generic type declaration_entity = (Context.generic -> Context.generic) entity val binding_morphism: string -> (binding -> binding) -> morphism val typ_morphism': string -> (theory -> typ -> typ) -> morphism val typ_morphism: string -> (typ -> typ) -> morphism val term_morphism': string -> (theory -> term -> term) -> morphism val term_morphism: string -> (term -> term) -> morphism val fact_morphism': string -> (theory -> thm list -> thm list) -> morphism val fact_morphism: string -> (thm list -> thm list) -> morphism val thm_morphism': string -> (theory -> thm -> thm) -> morphism val thm_morphism: string -> (thm -> thm) -> morphism val transfer_morphism: theory -> morphism val transfer_morphism': Proof.context -> morphism val transfer_morphism'': Context.generic -> morphism val trim_context_morphism: morphism val set_trim_context: theory -> morphism -> morphism val set_trim_context': Proof.context -> morphism -> morphism val set_trim_context'': Context.generic -> morphism -> morphism val instantiate_frees_morphism: ctyp TFrees.table * cterm Frees.table -> morphism val instantiate_morphism: ctyp TVars.table * cterm Vars.table -> morphism end; structure Morphism: MORPHISM = struct (* named functions *) type 'a funs = (string * (theory option -> 'a -> 'a)) list; exception MORPHISM of string * exn; fun app context (name, f) x = f context x handle exn => if Exn.is_interrupt exn then Exn.reraise exn else raise MORPHISM (name, exn); (* optional context *) fun the_theory (SOME thy) = thy | the_theory NONE = raise Fail "Morphism lacks theory context"; fun join_transfer (SOME thy) = Thm.join_transfer thy | join_transfer NONE = I; val join_context = join_options Context.join_certificate_theory; (* type morphism *) datatype morphism = Morphism of {context: theory option, names: string list, binding: binding funs, typ: typ funs, term: term funs, fact: thm list funs}; fun rep (Morphism args) = args; fun apply which phi = let val args = rep phi in fold_rev (app (#context args)) (which args) end; fun put_context context (Morphism {context = _, names, binding, typ, term, fact}) = Morphism {context = context, names = names, binding = binding, typ = typ, term = term, fact = fact}; val set_context = put_context o SOME; val set_context' = set_context o Proof_Context.theory_of; val set_context'' = set_context o Context.theory_of; val reset_context = put_context NONE; fun morphism a {binding, typ, term, fact} = Morphism { context = NONE, names = if a = "" then [] else [a], binding = map (pair a) binding, typ = map (pair a) typ, term = map (pair a) term, fact = map (pair a) fact}; (*syntactic test only!*) fun is_identity (Morphism {context = _, names, binding, typ, term, fact}) = null names andalso null binding andalso null typ andalso null term andalso null fact; fun is_empty phi = is_none (#context (rep phi)) andalso is_identity phi; fun pretty phi = Pretty.enum ";" "{" "}" (map Pretty.str (rev (#names (rep phi)))); val _ = ML_system_pp (fn _ => fn _ => Pretty.to_polyml o pretty); val binding = apply #binding; fun binding_prefix morph = Binding.name "x" |> binding morph |> Binding.prefix_of; val typ = apply #typ; val term = apply #term; fun fact phi = map (join_transfer (#context (rep phi))) #> apply #fact phi; val thm = singleton o fact; val cterm = Drule.cterm_rule o thm; (* morphism combinators *) val identity = morphism "" {binding = [], typ = [], term = [], fact = []}; val default = the_default identity; fun compose phi1 phi2 = if is_empty phi1 then phi2 else if is_empty phi2 then phi1 else let val {context = context1, names = names1, binding = binding1, typ = typ1, term = term1, fact = fact1} = rep phi1; val {context = context2, names = names2, binding = binding2, typ = typ2, term = term2, fact = fact2} = rep phi2; in Morphism { context = join_context (context1, context2), names = names1 @ names2, binding = binding1 @ binding2, typ = typ1 @ typ2, term = term1 @ term2, fact = fact1 @ fact2} end; fun phi1 $> phi2 = compose phi2 phi1; (* abstract entities *) datatype 'a entity = Entity of (morphism -> 'a) * morphism; fun entity f = Entity (f, identity); fun entity_morphism g (Entity (f, phi)) = Entity (f, g phi); fun entity_reset_context a = entity_morphism reset_context a; fun entity_set_context thy a = entity_morphism (set_context thy) a; fun entity_set_context' ctxt a = entity_morphism (set_context' ctxt) a; fun entity_set_context'' context a = entity_morphism (set_context'' context) a; fun transform phi = entity_morphism (compose phi); fun transform_reset_context phi = entity_morphism (reset_context o compose phi); fun form (Entity (f, phi)) = f phi; fun form_entity f = f identity; +fun form_context thy x = form (entity_set_context thy x) thy; +fun form_context' ctxt x = form (entity_set_context' ctxt x) ctxt; +fun form_context'' context x = form (entity_set_context'' context x) context; + type declaration = morphism -> Context.generic -> Context.generic; type declaration_entity = (Context.generic -> Context.generic) entity; (* concrete morphisms *) fun binding_morphism a binding = morphism a {binding = [K binding], typ = [], term = [], fact = []}; fun typ_morphism' a typ = morphism a {binding = [], typ = [typ o the_theory], term = [], fact = []}; fun typ_morphism a typ = morphism a {binding = [], typ = [K typ], term = [], fact = []}; fun term_morphism' a term = morphism a {binding = [], typ = [], term = [term o the_theory], fact = []}; fun term_morphism a term = morphism a {binding = [], typ = [], term = [K term], fact = []}; fun fact_morphism' a fact = morphism a {binding = [], typ = [], term = [], fact = [fact o the_theory]}; fun fact_morphism a fact = morphism a {binding = [], typ = [], term = [], fact = [K fact]}; fun thm_morphism' a thm = morphism a {binding = [], typ = [], term = [], fact = [map o thm o the_theory]}; fun thm_morphism a thm = morphism a {binding = [], typ = [], term = [], fact = [K (map thm)]}; fun transfer_morphism thy = fact_morphism "transfer" I |> set_context thy; val transfer_morphism' = transfer_morphism o Proof_Context.theory_of; val transfer_morphism'' = transfer_morphism o Context.theory_of; val trim_context_morphism = thm_morphism "trim_context" Thm.trim_context; fun set_trim_context thy phi = set_context thy phi $> trim_context_morphism; val set_trim_context' = set_trim_context o Proof_Context.theory_of; val set_trim_context'' = set_trim_context o Context.theory_of; (* instantiate *) fun instantiate_frees_morphism (cinstT, cinst) = if TFrees.is_empty cinstT andalso Frees.is_empty cinst then identity else let val instT = TFrees.map (K Thm.typ_of) cinstT; val inst = Frees.map (K Thm.term_of) cinst; in morphism "instantiate_frees" {binding = [], typ = if TFrees.is_empty instT then [] else [K (Term_Subst.instantiateT_frees instT)], term = [K (Term_Subst.instantiate_frees (instT, inst))], fact = [K (map (Thm.instantiate_frees (cinstT, cinst)))]} end; fun instantiate_morphism (cinstT, cinst) = if TVars.is_empty cinstT andalso Vars.is_empty cinst then identity else let val instT = TVars.map (K Thm.typ_of) cinstT; val inst = Vars.map (K Thm.term_of) cinst; in morphism "instantiate" {binding = [], typ = if TVars.is_empty instT then [] else [K (Term_Subst.instantiateT instT)], term = [K (Term_Subst.instantiate (instT, inst))], fact = [K (map (Thm.instantiate (cinstT, cinst)))]} end; end; structure Basic_Morphism: BASIC_MORPHISM = Morphism; open Basic_Morphism; diff --git a/src/Pure/raw_simplifier.ML b/src/Pure/raw_simplifier.ML --- a/src/Pure/raw_simplifier.ML +++ b/src/Pure/raw_simplifier.ML @@ -1,1456 +1,1464 @@ (* Title: Pure/raw_simplifier.ML Author: Tobias Nipkow and Stefan Berghofer, TU Muenchen Higher-order Simplification. *) infix 4 addsimps delsimps addsimprocs delsimprocs setloop addloop delloop setSSolver addSSolver setSolver addSolver; signature BASIC_RAW_SIMPLIFIER = sig val simp_depth_limit: int Config.T val simp_trace_depth_limit: int Config.T val simp_debug: bool Config.T val simp_trace: bool Config.T type cong_name = bool * string type rrule val mk_rrules: Proof.context -> thm list -> rrule list val eq_rrule: rrule * rrule -> bool type proc type solver val mk_solver: string -> (Proof.context -> int -> tactic) -> solver type simpset val empty_ss: simpset val merge_ss: simpset * simpset -> simpset val dest_ss: simpset -> {simps: (string * thm) list, procs: (string * term list) list, congs: (cong_name * thm) list, weak_congs: cong_name list, loopers: string list, unsafe_solvers: string list, safe_solvers: string list} type simproc val eq_simproc: simproc * simproc -> bool val cert_simproc: theory -> string -> {lhss: term list, proc: (Proof.context -> cterm -> thm option) Morphism.entity} -> simproc val transform_simproc: morphism -> simproc -> simproc + val trim_context_simproc: simproc -> simproc val simpset_of: Proof.context -> simpset val put_simpset: simpset -> Proof.context -> Proof.context val simpset_map: Proof.context -> (Proof.context -> Proof.context) -> simpset -> simpset val map_theory_simpset: (Proof.context -> Proof.context) -> theory -> theory val empty_simpset: Proof.context -> Proof.context val clear_simpset: Proof.context -> Proof.context val addsimps: Proof.context * thm list -> Proof.context val delsimps: Proof.context * thm list -> Proof.context val addsimprocs: Proof.context * simproc list -> Proof.context val delsimprocs: Proof.context * simproc list -> Proof.context val setloop: Proof.context * (Proof.context -> int -> tactic) -> Proof.context val addloop: Proof.context * (string * (Proof.context -> int -> tactic)) -> Proof.context val delloop: Proof.context * string -> Proof.context val setSSolver: Proof.context * solver -> Proof.context val addSSolver: Proof.context * solver -> Proof.context val setSolver: Proof.context * solver -> Proof.context val addSolver: Proof.context * solver -> Proof.context val rewrite_rule: Proof.context -> thm list -> thm -> thm val rewrite_goals_rule: Proof.context -> thm list -> thm -> thm val rewrite_goals_tac: Proof.context -> thm list -> tactic val rewrite_goal_tac: Proof.context -> thm list -> int -> tactic val prune_params_tac: Proof.context -> tactic val fold_rule: Proof.context -> thm list -> thm -> thm val fold_goals_tac: Proof.context -> thm list -> tactic val norm_hhf: Proof.context -> thm -> thm val norm_hhf_protect: Proof.context -> thm -> thm end; signature RAW_SIMPLIFIER = sig include BASIC_RAW_SIMPLIFIER exception SIMPLIFIER of string * thm list type trace_ops val set_trace_ops: trace_ops -> theory -> theory val subgoal_tac: Proof.context -> int -> tactic val loop_tac: Proof.context -> int -> tactic val solvers: Proof.context -> solver list * solver list val map_ss: (Proof.context -> Proof.context) -> Context.generic -> Context.generic val prems_of: Proof.context -> thm list val add_simp: thm -> Proof.context -> Proof.context val del_simp: thm -> Proof.context -> Proof.context val flip_simp: thm -> Proof.context -> Proof.context val init_simpset: thm list -> Proof.context -> Proof.context val add_eqcong: thm -> Proof.context -> Proof.context val del_eqcong: thm -> Proof.context -> Proof.context val add_cong: thm -> Proof.context -> Proof.context val del_cong: thm -> Proof.context -> Proof.context val mksimps: Proof.context -> thm -> thm list val set_mksimps: (Proof.context -> thm -> thm list) -> Proof.context -> Proof.context val set_mkcong: (Proof.context -> thm -> thm) -> Proof.context -> Proof.context val set_mksym: (Proof.context -> thm -> thm option) -> Proof.context -> Proof.context val set_mkeqTrue: (Proof.context -> thm -> thm option) -> Proof.context -> Proof.context val set_term_ord: term ord -> Proof.context -> Proof.context val set_subgoaler: (Proof.context -> int -> tactic) -> Proof.context -> Proof.context val solver: Proof.context -> solver -> int -> tactic val default_mk_sym: Proof.context -> thm -> thm option val add_prems: thm list -> Proof.context -> Proof.context val set_reorient: (Proof.context -> term list -> term -> term -> bool) -> Proof.context -> Proof.context val set_solvers: solver list -> Proof.context -> Proof.context val rewrite_cterm: bool * bool * bool -> (Proof.context -> thm -> thm option) -> Proof.context -> conv val rewrite_term: theory -> thm list -> (term -> term option) list -> term -> term val rewrite_thm: bool * bool * bool -> (Proof.context -> thm -> thm option) -> Proof.context -> thm -> thm val generic_rewrite_goal_tac: bool * bool * bool -> (Proof.context -> tactic) -> Proof.context -> int -> tactic val rewrite: Proof.context -> bool -> thm list -> conv end; structure Raw_Simplifier: RAW_SIMPLIFIER = struct (** datatype simpset **) (* congruence rules *) type cong_name = bool * string; fun cong_name (Const (a, _)) = SOME (true, a) | cong_name (Free (a, _)) = SOME (false, a) | cong_name _ = NONE; structure Congtab = Table(type key = cong_name val ord = prod_ord bool_ord fast_string_ord); (* rewrite rules *) type rrule = {thm: thm, (*the rewrite rule*) name: string, (*name of theorem from which rewrite rule was extracted*) lhs: term, (*the left-hand side*) elhs: cterm, (*the eta-contracted lhs*) extra: bool, (*extra variables outside of elhs*) fo: bool, (*use first-order matching*) perm: bool}; (*the rewrite rule is permutative*) fun trim_context_rrule ({thm, name, lhs, elhs, extra, fo, perm}: rrule) = {thm = Thm.trim_context thm, name = name, lhs = lhs, elhs = Thm.trim_context_cterm elhs, extra = extra, fo = fo, perm = perm}; (* Remarks: - elhs is used for matching, lhs only for preservation of bound variable names; - fo is set iff either elhs is first-order (no Var is applied), in which case fo-matching is complete, or elhs is not a pattern, in which case there is nothing better to do; *) fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) = Thm.eq_thm_prop (thm1, thm2); (* FIXME: it seems that the conditions on extra variables are too liberal if prems are nonempty: does solving the prems really guarantee instantiation of all its Vars? Better: a dynamic check each time a rule is applied. *) fun rewrite_rule_extra_vars prems elhs erhs = let val elhss = elhs :: prems; val tvars = TVars.build (fold TVars.add_tvars elhss); val vars = Vars.build (fold Vars.add_vars elhss); in erhs |> Term.exists_type (Term.exists_subtype (fn TVar v => not (TVars.defined tvars v) | _ => false)) orelse erhs |> Term.exists_subterm (fn Var v => not (Vars.defined vars v) | _ => false) end; fun rrule_extra_vars elhs thm = rewrite_rule_extra_vars [] (Thm.term_of elhs) (Thm.full_prop_of thm); fun mk_rrule2 {thm, name, lhs, elhs, perm} = let val t = Thm.term_of elhs; val fo = Pattern.first_order t orelse not (Pattern.pattern t); val extra = rrule_extra_vars elhs thm; in {thm = thm, name = name, lhs = lhs, elhs = elhs, extra = extra, fo = fo, perm = perm} end; (*simple test for looping rewrite rules and stupid orientations*) fun default_reorient ctxt prems lhs rhs = rewrite_rule_extra_vars prems lhs rhs orelse is_Var (head_of lhs) orelse (* turns t = x around, which causes a headache if x is a local variable - usually it is very useful :-( is_Free rhs andalso not(is_Free lhs) andalso not(Logic.occs(rhs,lhs)) andalso not(exists_subterm is_Var lhs) orelse *) exists (fn t => Logic.occs (lhs, t)) (rhs :: prems) orelse null prems andalso Pattern.matches (Proof_Context.theory_of ctxt) (lhs, rhs) (*the condition "null prems" is necessary because conditional rewrites with extra variables in the conditions may terminate although the rhs is an instance of the lhs; example: ?m < ?n \ f ?n \ f ?m *) orelse is_Const lhs andalso not (is_Const rhs); (* simplification procedures *) datatype proc = Proc of {name: string, lhs: term, - proc: Proof.context -> cterm -> thm option, + proc: (Proof.context -> cterm -> thm option) Morphism.entity, stamp: stamp}; fun eq_proc (Proc {stamp = stamp1, ...}, Proc {stamp = stamp2, ...}) = stamp1 = stamp2; (* solvers *) datatype solver = Solver of {name: string, solver: Proof.context -> int -> tactic, id: stamp}; fun mk_solver name solver = Solver {name = name, solver = solver, id = stamp ()}; fun solver_name (Solver {name, ...}) = name; fun solver ctxt (Solver {solver = tac, ...}) = tac ctxt; fun eq_solver (Solver {id = id1, ...}, Solver {id = id2, ...}) = (id1 = id2); (* simplification sets *) (*A simpset contains data required during conversion: rules: discrimination net of rewrite rules; prems: current premises; depth: simp_depth and exceeded flag; congs: association list of congruence rules and a list of `weak' congruence constants. A congruence is `weak' if it avoids normalization of some argument. procs: discrimination net of simplification procedures (functions that prove rewrite rules on the fly); mk_rews: mk: turn simplification thms into rewrite rules; mk_cong: prepare congruence rules; mk_sym: turn \ around; mk_eq_True: turn P into P \ True; term_ord: for ordered rewriting;*) datatype simpset = Simpset of {rules: rrule Net.net, prems: thm list, depth: int * bool Unsynchronized.ref} * {congs: thm Congtab.table * cong_name list, procs: proc Net.net, mk_rews: {mk: Proof.context -> thm -> thm list, mk_cong: Proof.context -> thm -> thm, mk_sym: Proof.context -> thm -> thm option, mk_eq_True: Proof.context -> thm -> thm option, reorient: Proof.context -> term list -> term -> term -> bool}, term_ord: term ord, subgoal_tac: Proof.context -> int -> tactic, loop_tacs: (string * (Proof.context -> int -> tactic)) list, solvers: solver list * solver list}; fun internal_ss (Simpset (_, ss2)) = ss2; fun make_ss1 (rules, prems, depth) = {rules = rules, prems = prems, depth = depth}; fun map_ss1 f {rules, prems, depth} = make_ss1 (f (rules, prems, depth)); fun make_ss2 (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) = {congs = congs, procs = procs, mk_rews = mk_rews, term_ord = term_ord, subgoal_tac = subgoal_tac, loop_tacs = loop_tacs, solvers = solvers}; fun map_ss2 f {congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers} = make_ss2 (f (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers)); fun make_simpset (args1, args2) = Simpset (make_ss1 args1, make_ss2 args2); fun dest_ss (Simpset ({rules, ...}, {congs, procs, loop_tacs, solvers, ...})) = {simps = Net.entries rules |> map (fn {name, thm, ...} => (name, thm)), procs = Net.entries procs |> map (fn Proc {name, lhs, stamp, ...} => ((name, lhs), stamp)) |> partition_eq (eq_snd op =) |> map (fn ps => (fst (fst (hd ps)), map (snd o fst) ps)), congs = congs |> fst |> Congtab.dest, weak_congs = congs |> snd, loopers = map fst loop_tacs, unsafe_solvers = map solver_name (#1 solvers), safe_solvers = map solver_name (#2 solvers)}; (* empty *) fun init_ss depth mk_rews term_ord subgoal_tac solvers = make_simpset ((Net.empty, [], depth), ((Congtab.empty, []), Net.empty, mk_rews, term_ord, subgoal_tac, [], solvers)); fun default_mk_sym _ th = SOME (th RS Drule.symmetric_thm); val empty_ss = init_ss (0, Unsynchronized.ref false) {mk = fn _ => fn th => if can Logic.dest_equals (Thm.concl_of th) then [th] else [], mk_cong = K I, mk_sym = default_mk_sym, mk_eq_True = K (K NONE), reorient = default_reorient} Term_Ord.term_ord (K (K no_tac)) ([], []); (* merge *) (*NOTE: ignores some fields of 2nd simpset*) fun merge_ss (ss1, ss2) = if pointer_eq (ss1, ss2) then ss1 else let val Simpset ({rules = rules1, prems = prems1, depth = depth1}, {congs = (congs1, weak1), procs = procs1, mk_rews, term_ord, subgoal_tac, loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1; val Simpset ({rules = rules2, prems = prems2, depth = depth2}, {congs = (congs2, weak2), procs = procs2, mk_rews = _, term_ord = _, subgoal_tac = _, loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2; val rules' = Net.merge eq_rrule (rules1, rules2); val prems' = Thm.merge_thms (prems1, prems2); val depth' = if #1 depth1 < #1 depth2 then depth2 else depth1; val congs' = Congtab.merge (K true) (congs1, congs2); val weak' = merge (op =) (weak1, weak2); val procs' = Net.merge eq_proc (procs1, procs2); val loop_tacs' = AList.merge (op =) (K true) (loop_tacs1, loop_tacs2); val unsafe_solvers' = merge eq_solver (unsafe_solvers1, unsafe_solvers2); val solvers' = merge eq_solver (solvers1, solvers2); in make_simpset ((rules', prems', depth'), ((congs', weak'), procs', mk_rews, term_ord, subgoal_tac, loop_tacs', (unsafe_solvers', solvers'))) end; (** context data **) structure Simpset = Generic_Data ( type T = simpset; val empty = empty_ss; val merge = merge_ss; ); val simpset_of = Simpset.get o Context.Proof; fun map_simpset f = Context.proof_map (Simpset.map f); fun map_simpset1 f = map_simpset (fn Simpset (ss1, ss2) => Simpset (map_ss1 f ss1, ss2)); fun map_simpset2 f = map_simpset (fn Simpset (ss1, ss2) => Simpset (ss1, map_ss2 f ss2)); fun put_simpset ss = map_simpset (K ss); fun simpset_map ctxt f ss = ctxt |> put_simpset ss |> f |> simpset_of; val empty_simpset = put_simpset empty_ss; fun map_theory_simpset f thy = let val ctxt' = f (Proof_Context.init_global thy); val thy' = Proof_Context.theory_of ctxt'; in Context.theory_map (Simpset.map (K (simpset_of ctxt'))) thy' end; fun map_ss f = Context.mapping (map_theory_simpset (f o Context_Position.not_really)) f; val clear_simpset = map_simpset (fn Simpset ({depth, ...}, {mk_rews, term_ord, subgoal_tac, solvers, ...}) => init_ss depth mk_rews term_ord subgoal_tac solvers); (* accessors for tactis *) fun subgoal_tac ctxt = (#subgoal_tac o internal_ss o simpset_of) ctxt ctxt; fun loop_tac ctxt = FIRST' (map (fn (_, tac) => tac ctxt) (rev ((#loop_tacs o internal_ss o simpset_of) ctxt))); val solvers = #solvers o internal_ss o simpset_of (* simp depth *) (* The simp_depth_limit is meant to abort infinite recursion of the simplifier early but should not terminate "normal" executions. As of 2017, 25 would suffice; 40 builds in a safety margin. *) val simp_depth_limit = Config.declare_int ("simp_depth_limit", \<^here>) (K 40); val simp_trace_depth_limit = Config.declare_int ("simp_trace_depth_limit", \<^here>) (K 1); fun inc_simp_depth ctxt = ctxt |> map_simpset1 (fn (rules, prems, (depth, exceeded)) => (rules, prems, (depth + 1, if depth = Config.get ctxt simp_trace_depth_limit then Unsynchronized.ref false else exceeded))); fun simp_depth ctxt = let val Simpset ({depth = (depth, _), ...}, _) = simpset_of ctxt in depth end; (* diagnostics *) exception SIMPLIFIER of string * thm list; val simp_debug = Config.declare_bool ("simp_debug", \<^here>) (K false); val simp_trace = Config.declare_bool ("simp_trace", \<^here>) (K false); fun cond_warning ctxt msg = if Context_Position.is_really_visible ctxt then warning (msg ()) else (); fun cond_tracing' ctxt flag msg = if Config.get ctxt flag then let val Simpset ({depth = (depth, exceeded), ...}, _) = simpset_of ctxt; val depth_limit = Config.get ctxt simp_trace_depth_limit; in if depth > depth_limit then if ! exceeded then () else (tracing "simp_trace_depth_limit exceeded!"; exceeded := true) else (tracing (enclose "[" "]" (string_of_int depth) ^ msg ()); exceeded := false) end else (); fun cond_tracing ctxt = cond_tracing' ctxt simp_trace; fun print_term ctxt s t = s ^ "\n" ^ Syntax.string_of_term ctxt t; fun print_thm ctxt s (name, th) = print_term ctxt (if name = "" then s else s ^ " " ^ quote name ^ ":") (Thm.full_prop_of th); (** simpset operations **) (* prems *) fun prems_of ctxt = let val Simpset ({prems, ...}, _) = simpset_of ctxt in prems end; fun add_prems ths = map_simpset1 (fn (rules, prems, depth) => (rules, ths @ prems, depth)); (* maintain simp rules *) fun del_rrule loud (rrule as {thm, elhs, ...}) ctxt = ctxt |> map_simpset1 (fn (rules, prems, depth) => (Net.delete_term eq_rrule (Thm.term_of elhs, rrule) rules, prems, depth)) handle Net.DELETE => (if not loud then () else cond_warning ctxt (fn () => print_thm ctxt "Rewrite rule not in simpset:" ("", thm)); ctxt); fun insert_rrule (rrule as {thm, name, ...}) ctxt = (cond_tracing ctxt (fn () => print_thm ctxt "Adding rewrite rule" (name, thm)); ctxt |> map_simpset1 (fn (rules, prems, depth) => let val rrule2 as {elhs, ...} = mk_rrule2 rrule; val rules' = Net.insert_term eq_rrule (Thm.term_of elhs, trim_context_rrule rrule2) rules; in (rules', prems, depth) end) handle Net.INSERT => (cond_warning ctxt (fn () => print_thm ctxt "Ignoring duplicate rewrite rule:" ("", thm)); ctxt)); val vars_set = Vars.build o Vars.add_vars; local fun vperm (Var _, Var _) = true | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t) | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2) | vperm (t, u) = (t = u); fun var_perm (t, u) = vperm (t, u) andalso Vars.eq_set (apply2 vars_set (t, u)); in fun decomp_simp thm = let val prop = Thm.prop_of thm; val prems = Logic.strip_imp_prems prop; val concl = Drule.strip_imp_concl (Thm.cprop_of thm); val (lhs, rhs) = Thm.dest_equals concl handle TERM _ => raise SIMPLIFIER ("Rewrite rule not a meta-equality", [thm]); val elhs = Thm.dest_arg (Thm.cprop_of (Thm.eta_conversion lhs)); val erhs = Envir.eta_contract (Thm.term_of rhs); val perm = var_perm (Thm.term_of elhs, erhs) andalso not (Thm.term_of elhs aconv erhs) andalso not (is_Var (Thm.term_of elhs)); in (prems, Thm.term_of lhs, elhs, Thm.term_of rhs, perm) end; end; fun decomp_simp' thm = let val (_, lhs, _, rhs, _) = decomp_simp thm in if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", [thm]) else (lhs, rhs) end; fun mk_eq_True ctxt (thm, name) = let val Simpset (_, {mk_rews = {mk_eq_True, ...}, ...}) = simpset_of ctxt in (case mk_eq_True ctxt thm of NONE => [] | SOME eq_True => let val (_, lhs, elhs, _, _) = decomp_simp eq_True; in [{thm = eq_True, name = name, lhs = lhs, elhs = elhs, perm = false}] end) end; (*create the rewrite rule and possibly also the eq_True variant, in case there are extra vars on the rhs*) fun rrule_eq_True ctxt thm name lhs elhs rhs thm2 = let val rrule = {thm = thm, name = name, lhs = lhs, elhs = elhs, perm = false} in if rewrite_rule_extra_vars [] lhs rhs then mk_eq_True ctxt (thm2, name) @ [rrule] else [rrule] end; fun mk_rrule ctxt (thm, name) = let val (prems, lhs, elhs, rhs, perm) = decomp_simp thm in if perm then [{thm = thm, name = name, lhs = lhs, elhs = elhs, perm = true}] else (*weak test for loops*) if rewrite_rule_extra_vars prems lhs rhs orelse is_Var (Thm.term_of elhs) then mk_eq_True ctxt (thm, name) else rrule_eq_True ctxt thm name lhs elhs rhs thm end |> map (fn {thm, name, lhs, elhs, perm} => {thm = Thm.trim_context thm, name = name, lhs = lhs, elhs = Thm.trim_context_cterm elhs, perm = perm}); fun orient_rrule ctxt (thm, name) = let val (prems, lhs, elhs, rhs, perm) = decomp_simp thm; val Simpset (_, {mk_rews = {reorient, mk_sym, ...}, ...}) = simpset_of ctxt; in if perm then [{thm = thm, name = name, lhs = lhs, elhs = elhs, perm = true}] else if reorient ctxt prems lhs rhs then if reorient ctxt prems rhs lhs then mk_eq_True ctxt (thm, name) else (case mk_sym ctxt thm of NONE => [] | SOME thm' => let val (_, lhs', elhs', rhs', _) = decomp_simp thm' in rrule_eq_True ctxt thm' name lhs' elhs' rhs' thm end) else rrule_eq_True ctxt thm name lhs elhs rhs thm end; fun extract_rews ctxt sym thms = let val Simpset (_, {mk_rews = {mk, ...}, ...}) = simpset_of ctxt; val mk = if sym then fn ctxt => fn th => (mk ctxt th) RL [Drule.symmetric_thm] else mk in maps (fn thm => map (rpair (Thm.get_name_hint thm)) (mk ctxt thm)) thms end; fun extract_safe_rrules ctxt thm = maps (orient_rrule ctxt) (extract_rews ctxt false [thm]); fun mk_rrules ctxt thms = let val rews = extract_rews ctxt false thms val raw_rrules = flat (map (mk_rrule ctxt) rews) in map mk_rrule2 raw_rrules end (* add/del rules explicitly *) local fun comb_simps ctxt comb mk_rrule sym thms = let val rews = extract_rews ctxt sym (map (Thm.transfer' ctxt) thms); in fold (fold comb o mk_rrule) rews ctxt end; (* This code checks if the symetric version of a rule is already in the simpset. However, the variable names in the two versions of the rule may differ. Thus the current test modulo eq_rrule is too weak to be useful and needs to be refined. fun present ctxt rules (rrule as {thm, elhs, ...}) = (Net.insert_term eq_rrule (Thm.term_of elhs, trim_context_rrule rrule) rules; false) handle Net.INSERT => (cond_warning ctxt (fn () => print_thm ctxt "Symmetric rewrite rule already in simpset:" ("", thm)); true); fun sym_present ctxt thms = let val rews = extract_rews ctxt true (map (Thm.transfer' ctxt) thms); val rrules = map mk_rrule2 (flat(map (mk_rrule ctxt) rews)) val Simpset({rules, ...},_) = simpset_of ctxt in exists (present ctxt rules) rrules end *) in fun ctxt addsimps thms = comb_simps ctxt insert_rrule (mk_rrule ctxt) false thms; fun addsymsimps ctxt thms = comb_simps ctxt insert_rrule (mk_rrule ctxt) true thms; fun ctxt delsimps thms = comb_simps ctxt (del_rrule true) (map mk_rrule2 o mk_rrule ctxt) false thms; fun delsimps_quiet ctxt thms = comb_simps ctxt (del_rrule false) (map mk_rrule2 o mk_rrule ctxt) false thms; fun add_simp thm ctxt = ctxt addsimps [thm]; (* with check for presence of symmetric version: if sym_present ctxt [thm] then (cond_warning ctxt (fn () => print_thm ctxt "Ignoring rewrite rule:" ("", thm)); ctxt) else ctxt addsimps [thm]; *) fun del_simp thm ctxt = ctxt delsimps [thm]; fun flip_simp thm ctxt = addsymsimps (delsimps_quiet ctxt [thm]) [thm]; end; fun init_simpset thms ctxt = ctxt |> Context_Position.set_visible false |> empty_simpset |> fold add_simp thms |> Context_Position.restore_visible ctxt; (* congs *) local fun is_full_cong_prems [] [] = true | is_full_cong_prems [] _ = false | is_full_cong_prems (p :: prems) varpairs = (case Logic.strip_assums_concl p of Const ("Pure.eq", _) $ lhs $ rhs => let val (x, xs) = strip_comb lhs and (y, ys) = strip_comb rhs in is_Var x andalso forall is_Bound xs andalso not (has_duplicates (op =) xs) andalso xs = ys andalso member (op =) varpairs (x, y) andalso is_full_cong_prems prems (remove (op =) (x, y) varpairs) end | _ => false); fun is_full_cong thm = let val prems = Thm.prems_of thm and concl = Thm.concl_of thm; val (lhs, rhs) = Logic.dest_equals concl; val (f, xs) = strip_comb lhs and (g, ys) = strip_comb rhs; in f = g andalso not (has_duplicates (op =) (xs @ ys)) andalso length xs = length ys andalso is_full_cong_prems prems (xs ~~ ys) end; fun mk_cong ctxt = let val Simpset (_, {mk_rews = {mk_cong = f, ...}, ...}) = simpset_of ctxt in f ctxt end; in fun add_eqcong thm ctxt = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) => let val (lhs, _) = Logic.dest_equals (Thm.concl_of thm) handle TERM _ => raise SIMPLIFIER ("Congruence not a meta-equality", [thm]); (*val lhs = Envir.eta_contract lhs;*) val a = the (cong_name (head_of lhs)) handle Option.Option => raise SIMPLIFIER ("Congruence must start with a constant or free variable", [thm]); val (xs, weak) = congs; val xs' = Congtab.update (a, Thm.trim_context thm) xs; val weak' = if is_full_cong thm then weak else a :: weak; in ((xs', weak'), procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) end); fun del_eqcong thm ctxt = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) => let val (lhs, _) = Logic.dest_equals (Thm.concl_of thm) handle TERM _ => raise SIMPLIFIER ("Congruence not a meta-equality", [thm]); (*val lhs = Envir.eta_contract lhs;*) val a = the (cong_name (head_of lhs)) handle Option.Option => raise SIMPLIFIER ("Congruence must start with a constant", [thm]); val (xs, _) = congs; val xs' = Congtab.delete_safe a xs; val weak' = Congtab.fold (fn (a, th) => if is_full_cong th then I else insert (op =) a) xs' []; in ((xs', weak'), procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) end); fun add_cong thm ctxt = add_eqcong (mk_cong ctxt thm) ctxt; fun del_cong thm ctxt = del_eqcong (mk_cong ctxt thm) ctxt; end; (* simprocs *) datatype simproc = Simproc of {name: string, lhss: term list, proc: (Proof.context -> cterm -> thm option) Morphism.entity, stamp: stamp}; fun eq_simproc (Simproc {stamp = stamp1, ...}, Simproc {stamp = stamp2, ...}) = stamp1 = stamp2; fun cert_simproc thy name {lhss, proc} = Simproc {name = name, lhss = map (Sign.cert_term thy) lhss, proc = proc, stamp = stamp ()}; fun transform_simproc phi (Simproc {name, lhss, proc, stamp}) = Simproc {name = name, lhss = map (Morphism.term phi) lhss, - proc = Morphism.transform_reset_context phi proc, + proc = Morphism.transform phi proc, + stamp = stamp}; + +fun trim_context_simproc (Simproc {name, lhss, proc, stamp}) = + Simproc + {name = name, + lhss = lhss, + proc = Morphism.entity_reset_context proc, stamp = stamp}; local fun add_proc (proc as Proc {name, lhs, ...}) ctxt = (cond_tracing ctxt (fn () => print_term ctxt ("Adding simplification procedure " ^ quote name ^ " for") lhs); ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) => (congs, Net.insert_term eq_proc (lhs, proc) procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers)) handle Net.INSERT => (cond_warning ctxt (fn () => "Ignoring duplicate simplification procedure " ^ quote name); ctxt)); fun del_proc (proc as Proc {name, lhs, ...}) ctxt = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) => (congs, Net.delete_term eq_proc (lhs, proc) procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers)) handle Net.DELETE => (cond_warning ctxt (fn () => "Simplification procedure " ^ quote name ^ " not in simpset"); ctxt); fun prep_procs (Simproc {name, lhss, proc, stamp}) = - lhss |> map (fn lhs => Proc {name = name, lhs = lhs, proc = Morphism.form proc, stamp = stamp}); + lhss |> map (fn lhs => Proc {name = name, lhs = lhs, proc = proc, stamp = stamp}); in fun ctxt addsimprocs ps = fold (fold add_proc o prep_procs) ps ctxt; fun ctxt delsimprocs ps = fold (fold del_proc o prep_procs) ps ctxt; end; (* mk_rews *) local fun map_mk_rews f = map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) => let val {mk, mk_cong, mk_sym, mk_eq_True, reorient} = mk_rews; val (mk', mk_cong', mk_sym', mk_eq_True', reorient') = f (mk, mk_cong, mk_sym, mk_eq_True, reorient); val mk_rews' = {mk = mk', mk_cong = mk_cong', mk_sym = mk_sym', mk_eq_True = mk_eq_True', reorient = reorient'}; in (congs, procs, mk_rews', term_ord, subgoal_tac, loop_tacs, solvers) end); in fun mksimps ctxt = let val Simpset (_, {mk_rews = {mk, ...}, ...}) = simpset_of ctxt in mk ctxt end; fun set_mksimps mk = map_mk_rews (fn (_, mk_cong, mk_sym, mk_eq_True, reorient) => (mk, mk_cong, mk_sym, mk_eq_True, reorient)); fun set_mkcong mk_cong = map_mk_rews (fn (mk, _, mk_sym, mk_eq_True, reorient) => (mk, mk_cong, mk_sym, mk_eq_True, reorient)); fun set_mksym mk_sym = map_mk_rews (fn (mk, mk_cong, _, mk_eq_True, reorient) => (mk, mk_cong, mk_sym, mk_eq_True, reorient)); fun set_mkeqTrue mk_eq_True = map_mk_rews (fn (mk, mk_cong, mk_sym, _, reorient) => (mk, mk_cong, mk_sym, mk_eq_True, reorient)); fun set_reorient reorient = map_mk_rews (fn (mk, mk_cong, mk_sym, mk_eq_True, _) => (mk, mk_cong, mk_sym, mk_eq_True, reorient)); end; (* term_ord *) fun set_term_ord term_ord = map_simpset2 (fn (congs, procs, mk_rews, _, subgoal_tac, loop_tacs, solvers) => (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers)); (* tactics *) fun set_subgoaler subgoal_tac = map_simpset2 (fn (congs, procs, mk_rews, term_ord, _, loop_tacs, solvers) => (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers)); fun ctxt setloop tac = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, _, solvers) => (congs, procs, mk_rews, term_ord, subgoal_tac, [("", tac)], solvers)); fun ctxt addloop (name, tac) = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) => (congs, procs, mk_rews, term_ord, subgoal_tac, AList.update (op =) (name, tac) loop_tacs, solvers)); fun ctxt delloop name = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) => (congs, procs, mk_rews, term_ord, subgoal_tac, (if AList.defined (op =) loop_tacs name then () else cond_warning ctxt (fn () => "No such looper in simpset: " ^ quote name); AList.delete (op =) name loop_tacs), solvers)); fun ctxt setSSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (unsafe_solvers, _)) => (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (unsafe_solvers, [solver]))); fun ctxt addSSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (unsafe_solvers, solvers)) => (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (unsafe_solvers, insert eq_solver solver solvers))); fun ctxt setSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (_, solvers)) => (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, ([solver], solvers))); fun ctxt addSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (unsafe_solvers, solvers)) => (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (insert eq_solver solver unsafe_solvers, solvers))); fun set_solvers solvers = map_simpset2 (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, _) => (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, (solvers, solvers))); (* trace operations *) type trace_ops = {trace_invoke: {depth: int, term: term} -> Proof.context -> Proof.context, trace_apply: {unconditional: bool, term: term, thm: thm, rrule: rrule} -> Proof.context -> (Proof.context -> (thm * term) option) -> (thm * term) option}; structure Trace_Ops = Theory_Data ( type T = trace_ops; val empty: T = {trace_invoke = fn _ => fn ctxt => ctxt, trace_apply = fn _ => fn ctxt => fn cont => cont ctxt}; fun merge (trace_ops, _) = trace_ops; ); val set_trace_ops = Trace_Ops.put; val trace_ops = Trace_Ops.get o Proof_Context.theory_of; fun trace_invoke args ctxt = #trace_invoke (trace_ops ctxt) args ctxt; fun trace_apply args ctxt = #trace_apply (trace_ops ctxt) args ctxt; (** rewriting **) (* Uses conversions, see: L C Paulson, A higher-order implementation of rewriting, Science of Computer Programming 3 (1983), pages 119-149. *) fun check_conv ctxt msg thm thm' = let val thm'' = Thm.transitive thm thm' handle THM _ => let val nthm' = Thm.transitive (Thm.symmetric (Drule.beta_eta_conversion (Thm.lhs_of thm'))) thm' in Thm.transitive thm nthm' handle THM _ => let val nthm = Thm.transitive thm (Drule.beta_eta_conversion (Thm.rhs_of thm)) in Thm.transitive nthm nthm' end end val _ = if msg then cond_tracing ctxt (fn () => print_thm ctxt "SUCCEEDED" ("", thm')) else (); in SOME thm'' end handle THM _ => let val _ $ _ $ prop0 = Thm.prop_of thm; val _ = cond_tracing ctxt (fn () => print_thm ctxt "Proved wrong theorem (bad subgoaler?)" ("", thm') ^ "\n" ^ print_term ctxt "Should have proved:" prop0); in NONE end; (* mk_procrule *) fun mk_procrule ctxt thm = let val (prems, lhs, elhs, rhs, _) = decomp_simp thm val thm' = Thm.close_derivation \<^here> thm; in if rewrite_rule_extra_vars prems lhs rhs then (cond_warning ctxt (fn () => print_thm ctxt "Extra vars on rhs:" ("", thm)); []) else [mk_rrule2 {thm = thm', name = "", lhs = lhs, elhs = elhs, perm = false}] end; (* rewritec: conversion to apply the meta simpset to a term *) (*Since the rewriting strategy is bottom-up, we avoid re-normalizing already normalized terms by carrying around the rhs of the rewrite rule just applied. This is called the `skeleton'. It is decomposed in parallel with the term. Once a Var is encountered, the corresponding term is already in normal form. skel0 is a dummy skeleton that is to enforce complete normalization.*) val skel0 = Bound 0; (*Use rhs as skeleton only if the lhs does not contain unnormalized bits. The latter may happen iff there are weak congruence rules for constants in the lhs.*) fun uncond_skel ((_, weak), (lhs, rhs)) = if null weak then rhs (*optimization*) else if exists_subterm (fn Const (a, _) => member (op =) weak (true, a) | Free (a, _) => member (op =) weak (false, a) | _ => false) lhs then skel0 else rhs; (*Behaves like unconditional rule if rhs does not contain vars not in the lhs. Otherwise those vars may become instantiated with unnormalized terms while the premises are solved.*) fun cond_skel (args as (_, (lhs, rhs))) = if Vars.subset (vars_set rhs, vars_set lhs) then uncond_skel args else skel0; (* Rewriting -- we try in order: (1) beta reduction (2) unconditional rewrite rules (3) conditional rewrite rules (4) simplification procedures IMPORTANT: rewrite rules must not introduce new Vars or TVars! *) fun rewritec (prover, maxt) ctxt t = let val thy = Proof_Context.theory_of ctxt; val Simpset ({rules, ...}, {congs, procs, term_ord, ...}) = simpset_of ctxt; val eta_thm = Thm.eta_conversion t; val eta_t' = Thm.rhs_of eta_thm; val eta_t = Thm.term_of eta_t'; fun rew rrule = let val {thm = thm0, name, lhs, elhs = elhs0, extra, fo, perm} = rrule; val thm = Thm.transfer thy thm0; val elhs = Thm.transfer_cterm thy elhs0; val prop = Thm.prop_of thm; val (rthm, elhs') = if maxt = ~1 orelse not extra then (thm, elhs) else (Thm.incr_indexes (maxt + 1) thm, Thm.incr_indexes_cterm (maxt + 1) elhs); val insts = if fo then Thm.first_order_match (elhs', eta_t') else Thm.match (elhs', eta_t'); val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm); val prop' = Thm.prop_of thm'; val unconditional = Logic.no_prems prop'; val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop'); val trace_args = {unconditional = unconditional, term = eta_t, thm = thm', rrule = rrule}; in if perm andalso is_greater_equal (term_ord (rhs', lhs')) then (cond_tracing ctxt (fn () => print_thm ctxt "Cannot apply permutative rewrite rule" (name, thm) ^ "\n" ^ print_thm ctxt "Term does not become smaller:" ("", thm')); NONE) else (cond_tracing ctxt (fn () => print_thm ctxt "Applying instance of rewrite rule" (name, thm)); if unconditional then (cond_tracing ctxt (fn () => print_thm ctxt "Rewriting:" ("", thm')); trace_apply trace_args ctxt (fn ctxt' => let val lr = Logic.dest_equals prop; val SOME thm'' = check_conv ctxt' false eta_thm thm'; in SOME (thm'', uncond_skel (congs, lr)) end)) else (cond_tracing ctxt (fn () => print_thm ctxt "Trying to rewrite:" ("", thm')); if simp_depth ctxt > Config.get ctxt simp_depth_limit then (cond_tracing ctxt (fn () => "simp_depth_limit exceeded - giving up"); NONE) else trace_apply trace_args ctxt (fn ctxt' => (case prover ctxt' thm' of NONE => (cond_tracing ctxt' (fn () => print_thm ctxt' "FAILED" ("", thm')); NONE) | SOME thm2 => (case check_conv ctxt' true eta_thm thm2 of NONE => NONE | SOME thm2' => let val concl = Logic.strip_imp_concl prop; val lr = Logic.dest_equals concl; in SOME (thm2', cond_skel (congs, lr)) end))))) end; fun rews [] = NONE | rews (rrule :: rrules) = let val opt = rew rrule handle Pattern.MATCH => NONE in (case opt of NONE => rews rrules | some => some) end; fun sort_rrules rrs = let fun is_simple ({thm, ...}: rrule) = (case Thm.prop_of thm of Const ("Pure.eq", _) $ _ $ _ => true | _ => false); fun sort [] (re1, re2) = re1 @ re2 | sort (rr :: rrs) (re1, re2) = if is_simple rr then sort rrs (rr :: re1, re2) else sort rrs (re1, rr :: re2); in sort rrs ([], []) end; fun proc_rews [] = NONE | proc_rews (Proc {name, proc, lhs, ...} :: ps) = if Pattern.matches (Proof_Context.theory_of ctxt) (lhs, Thm.term_of t) then (cond_tracing' ctxt simp_debug (fn () => print_term ctxt ("Trying procedure " ^ quote name ^ " on:") eta_t); - (case proc ctxt eta_t' of + (case Morphism.form_context' ctxt proc eta_t' of NONE => (cond_tracing' ctxt simp_debug (fn () => "FAILED"); proc_rews ps) | SOME raw_thm => (cond_tracing ctxt (fn () => print_thm ctxt ("Procedure " ^ quote name ^ " produced rewrite rule:") ("", raw_thm)); (case rews (mk_procrule ctxt raw_thm) of NONE => (cond_tracing ctxt (fn () => print_term ctxt ("IGNORED result of simproc " ^ quote name ^ " -- does not match") (Thm.term_of t)); proc_rews ps) | some => some)))) else proc_rews ps; in (case eta_t of Abs _ $ _ => SOME (Thm.transitive eta_thm (Thm.beta_conversion false eta_t'), skel0) | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of NONE => proc_rews (Net.match_term procs eta_t) | some => some)) end; (* conversion to apply a congruence rule to a term *) fun congc prover ctxt maxt cong t = let val rthm = Thm.incr_indexes (maxt + 1) cong; val rlhs = fst (Thm.dest_equals (Drule.strip_imp_concl (Thm.cprop_of rthm))); val insts = Thm.match (rlhs, t) (* Thm.match can raise Pattern.MATCH; is handled when congc is called *) val thm' = Thm.instantiate insts (Thm.rename_boundvars (Thm.term_of rlhs) (Thm.term_of t) rthm); val _ = cond_tracing ctxt (fn () => print_thm ctxt "Applying congruence rule:" ("", thm')); fun err (msg, thm) = (cond_tracing ctxt (fn () => print_thm ctxt msg ("", thm)); NONE); in (case prover thm' of NONE => err ("Congruence proof failed. Could not prove", thm') | SOME thm2 => (case check_conv ctxt true (Drule.beta_eta_conversion t) thm2 of NONE => err ("Congruence proof failed. Should not have proved", thm2) | SOME thm2' => if op aconv (apply2 Thm.term_of (Thm.dest_equals (Thm.cprop_of thm2'))) then NONE else SOME thm2')) end; val vA = (("A", 0), propT); val vB = (("B", 0), propT); val vC = (("C", 0), propT); fun transitive1 NONE NONE = NONE | transitive1 (SOME thm1) NONE = SOME thm1 | transitive1 NONE (SOME thm2) = SOME thm2 | transitive1 (SOME thm1) (SOME thm2) = SOME (Thm.transitive thm1 thm2); fun transitive2 thm = transitive1 (SOME thm); fun transitive3 thm = transitive1 thm o SOME; fun bottomc ((simprem, useprem, mutsimp), prover, maxidx) = let fun botc skel ctxt t = if is_Var skel then NONE else (case subc skel ctxt t of some as SOME thm1 => (case rewritec (prover, maxidx) ctxt (Thm.rhs_of thm1) of SOME (thm2, skel2) => transitive2 (Thm.transitive thm1 thm2) (botc skel2 ctxt (Thm.rhs_of thm2)) | NONE => some) | NONE => (case rewritec (prover, maxidx) ctxt t of SOME (thm2, skel2) => transitive2 thm2 (botc skel2 ctxt (Thm.rhs_of thm2)) | NONE => NONE)) and try_botc ctxt t = (case botc skel0 ctxt t of SOME trec1 => trec1 | NONE => Thm.reflexive t) and subc skel ctxt t0 = let val Simpset (_, {congs, ...}) = simpset_of ctxt in (case Thm.term_of t0 of Abs (a, _, _) => let val ((v, t'), ctxt') = Variable.dest_abs_cterm t0 ctxt; val skel' = (case skel of Abs (_, _, sk) => sk | _ => skel0); in (case botc skel' ctxt' t' of SOME thm => SOME (Thm.abstract_rule a v thm) | NONE => NONE) end | t $ _ => (case t of Const ("Pure.imp", _) $ _ => impc t0 ctxt | Abs _ => let val thm = Thm.beta_conversion false t0 in (case subc skel0 ctxt (Thm.rhs_of thm) of NONE => SOME thm | SOME thm' => SOME (Thm.transitive thm thm')) end | _ => let fun appc () = let val (tskel, uskel) = (case skel of tskel $ uskel => (tskel, uskel) | _ => (skel0, skel0)); val (ct, cu) = Thm.dest_comb t0; in (case botc tskel ctxt ct of SOME thm1 => (case botc uskel ctxt cu of SOME thm2 => SOME (Thm.combination thm1 thm2) | NONE => SOME (Thm.combination thm1 (Thm.reflexive cu))) | NONE => (case botc uskel ctxt cu of SOME thm1 => SOME (Thm.combination (Thm.reflexive ct) thm1) | NONE => NONE)) end; val (h, ts) = strip_comb t; in (case cong_name h of SOME a => (case Congtab.lookup (fst congs) a of NONE => appc () | SOME cong => (*post processing: some partial applications h t1 ... tj, j <= length ts, may be a redex. Example: map (\x. x) = (\xs. xs) wrt map_cong*) (let val thm = congc (prover ctxt) ctxt maxidx cong t0; val t = the_default t0 (Option.map Thm.rhs_of thm); val (cl, cr) = Thm.dest_comb t val dVar = Var(("", 0), dummyT) val skel = list_comb (h, replicate (length ts) dVar) in (case botc skel ctxt cl of NONE => thm | SOME thm' => transitive3 thm (Thm.combination thm' (Thm.reflexive cr))) end handle Pattern.MATCH => appc ())) | _ => appc ()) end) | _ => NONE) end and impc ct ctxt = if mutsimp then mut_impc0 [] ct [] [] ctxt else nonmut_impc ct ctxt and rules_of_prem prem ctxt = if maxidx_of_term (Thm.term_of prem) <> ~1 then (cond_tracing ctxt (fn () => print_term ctxt "Cannot add premise as rewrite rule because it contains (type) unknowns:" (Thm.term_of prem)); (([], NONE), ctxt)) else let val (asm, ctxt') = Thm.assume_hyps prem ctxt in ((extract_safe_rrules ctxt' asm, SOME asm), ctxt') end and add_rrules (rrss, asms) ctxt = (fold o fold) insert_rrule rrss ctxt |> add_prems (map_filter I asms) and disch r prem eq = let val (lhs, rhs) = Thm.dest_equals (Thm.cprop_of eq); val eq' = Thm.implies_elim (Thm.instantiate (TVars.empty, Vars.make3 (vA, prem) (vB, lhs) (vC, rhs)) Drule.imp_cong) (Thm.implies_intr prem eq); in if not r then eq' else let val (prem', concl) = Thm.dest_implies lhs; val (prem'', _) = Thm.dest_implies rhs; in Thm.transitive (Thm.transitive (Thm.instantiate (TVars.empty, Vars.make3 (vA, prem') (vB, prem) (vC, concl)) Drule.swap_prems_eq) eq') (Thm.instantiate (TVars.empty, Vars.make3 (vA, prem) (vB, prem'') (vC, concl)) Drule.swap_prems_eq) end end and rebuild [] _ _ _ _ eq = eq | rebuild (prem :: prems) concl (_ :: rrss) (_ :: asms) ctxt eq = let val ctxt' = add_rrules (rev rrss, rev asms) ctxt; val concl' = Drule.mk_implies (prem, the_default concl (Option.map Thm.rhs_of eq)); val dprem = Option.map (disch false prem); in (case rewritec (prover, maxidx) ctxt' concl' of NONE => rebuild prems concl' rrss asms ctxt (dprem eq) | SOME (eq', _) => transitive2 (fold (disch false) prems (the (transitive3 (dprem eq) eq'))) (mut_impc0 (rev prems) (Thm.rhs_of eq') (rev rrss) (rev asms) ctxt)) end and mut_impc0 prems concl rrss asms ctxt = let val prems' = strip_imp_prems concl; val ((rrss', asms'), ctxt') = fold_map rules_of_prem prems' ctxt |>> split_list; in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss') (asms @ asms') [] [] [] [] ctxt' ~1 ~1 end and mut_impc [] concl [] [] prems' rrss' asms' eqns ctxt changed k = transitive1 (fold (fn (eq1, prem) => fn eq2 => transitive1 eq1 (Option.map (disch false prem) eq2)) (eqns ~~ prems') NONE) (if changed > 0 then mut_impc (rev prems') concl (rev rrss') (rev asms') [] [] [] [] ctxt ~1 changed else rebuild prems' concl rrss' asms' ctxt (botc skel0 (add_rrules (rev rrss', rev asms') ctxt) concl)) | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms) prems' rrss' asms' eqns ctxt changed k = (case (if k = 0 then NONE else botc skel0 (add_rrules (rev rrss' @ rrss, rev asms' @ asms) ctxt) prem) of NONE => mut_impc prems concl rrss asms (prem :: prems') (rrs :: rrss') (asm :: asms') (NONE :: eqns) ctxt changed (if k = 0 then 0 else k - 1) | SOME eqn => let val prem' = Thm.rhs_of eqn; val tprems = map Thm.term_of prems; val i = 1 + fold Integer.max (map (fn p => find_index (fn q => q aconv p) tprems) (Thm.hyps_of eqn)) ~1; val ((rrs', asm'), ctxt') = rules_of_prem prem' ctxt; in mut_impc prems concl rrss asms (prem' :: prems') (rrs' :: rrss') (asm' :: asms') (SOME (fold_rev (disch true) (take i prems) (Drule.imp_cong_rule eqn (Thm.reflexive (Drule.list_implies (drop i prems, concl))))) :: eqns) ctxt' (length prems') ~1 end) (*legacy code -- only for backwards compatibility*) and nonmut_impc ct ctxt = let val (prem, conc) = Thm.dest_implies ct; val thm1 = if simprem then botc skel0 ctxt prem else NONE; val prem1 = the_default prem (Option.map Thm.rhs_of thm1); val ctxt1 = if not useprem then ctxt else let val ((rrs, asm), ctxt') = rules_of_prem prem1 ctxt in add_rrules ([rrs], [asm]) ctxt' end; in (case botc skel0 ctxt1 conc of NONE => (case thm1 of NONE => NONE | SOME thm1' => SOME (Drule.imp_cong_rule thm1' (Thm.reflexive conc))) | SOME thm2 => let val thm2' = disch false prem1 thm2 in (case thm1 of NONE => SOME thm2' | SOME thm1' => SOME (Thm.transitive (Drule.imp_cong_rule thm1' (Thm.reflexive conc)) thm2')) end) end; in try_botc end; (* Meta-rewriting: rewrites t to u and returns the theorem t \ u *) (* Parameters: mode = (simplify A, use A in simplifying B, use prems of B (if B is again a meta-impl.) to simplify A) when simplifying A \ B prover: how to solve premises in conditional rewrites and congruences *) fun rewrite_cterm mode prover raw_ctxt raw_ct = let val thy = Proof_Context.theory_of raw_ctxt; val ct = raw_ct |> Thm.transfer_cterm thy |> Thm.adjust_maxidx_cterm ~1; val maxidx = Thm.maxidx_of_cterm ct; val ctxt = raw_ctxt |> Variable.set_body true |> Context_Position.set_visible false |> inc_simp_depth |> (fn ctxt => trace_invoke {depth = simp_depth ctxt, term = Thm.term_of ct} ctxt); val _ = cond_tracing ctxt (fn () => print_term ctxt "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" (Thm.term_of ct)); in ct |> bottomc (mode, Option.map (Drule.flexflex_unique (SOME ctxt)) oo prover, maxidx) ctxt |> Thm.solve_constraints end; val simple_prover = SINGLE o (fn ctxt => ALLGOALS (resolve_tac ctxt (prems_of ctxt))); fun rewrite _ _ [] = Thm.reflexive | rewrite ctxt full thms = rewrite_cterm (full, false, false) simple_prover (init_simpset thms ctxt); fun rewrite_rule ctxt = Conv.fconv_rule o rewrite ctxt true; (*simple term rewriting -- no proof*) fun rewrite_term thy rules procs = Pattern.rewrite_term thy (map decomp_simp' rules) procs; fun rewrite_thm mode prover ctxt = Conv.fconv_rule (rewrite_cterm mode prover ctxt); (*Rewrite the subgoals of a proof state (represented by a theorem)*) fun rewrite_goals_rule ctxt thms th = Conv.fconv_rule (Conv.prems_conv ~1 (rewrite_cterm (true, true, true) simple_prover (init_simpset thms ctxt))) th; (** meta-rewriting tactics **) (*Rewrite all subgoals*) fun rewrite_goals_tac ctxt defs = PRIMITIVE (rewrite_goals_rule ctxt defs); (*Rewrite one subgoal*) fun generic_rewrite_goal_tac mode prover_tac ctxt i thm = if 0 < i andalso i <= Thm.nprems_of thm then Seq.single (Conv.gconv_rule (rewrite_cterm mode (SINGLE o prover_tac) ctxt) i thm) else Seq.empty; fun rewrite_goal_tac ctxt thms = generic_rewrite_goal_tac (true, false, false) (K no_tac) (init_simpset thms ctxt); (*Prunes all redundant parameters from the proof state by rewriting.*) fun prune_params_tac ctxt = rewrite_goals_tac ctxt [Drule.triv_forall_equality]; (* for folding definitions, handling critical pairs *) (*The depth of nesting in a term*) fun term_depth (Abs (_, _, t)) = 1 + term_depth t | term_depth (f $ t) = 1 + Int.max (term_depth f, term_depth t) | term_depth _ = 0; val lhs_of_thm = #1 o Logic.dest_equals o Thm.prop_of; (*folding should handle critical pairs! E.g. K \ Inl 0, S \ Inr (Inl 0) Returns longest lhs first to avoid folding its subexpressions.*) fun sort_lhs_depths defs = let val keylist = AList.make (term_depth o lhs_of_thm) defs val keys = sort_distinct (rev_order o int_ord) (map #2 keylist) in map (AList.find (op =) keylist) keys end; val rev_defs = sort_lhs_depths o map Thm.symmetric; fun fold_rule ctxt defs = fold (rewrite_rule ctxt) (rev_defs defs); fun fold_goals_tac ctxt defs = EVERY (map (rewrite_goals_tac ctxt) (rev_defs defs)); (* HHF normal form: \ before \, outermost \ generalized *) local fun gen_norm_hhf protect ss ctxt0 th0 = let val (ctxt, th) = Thm.join_transfer_context (ctxt0, th0); val th' = if Drule.is_norm_hhf protect (Thm.prop_of th) then th else Conv.fconv_rule (rewrite_cterm (true, false, false) (K (K NONE)) (put_simpset ss ctxt)) th; in th' |> Thm.adjust_maxidx_thm ~1 |> Variable.gen_all ctxt end; val hhf_ss = Context.the_local_context () |> init_simpset Drule.norm_hhf_eqs |> simpset_of; val hhf_protect_ss = Context.the_local_context () |> init_simpset Drule.norm_hhf_eqs |> add_eqcong Drule.protect_cong |> simpset_of; in val norm_hhf = gen_norm_hhf {protect = false} hhf_ss; val norm_hhf_protect = gen_norm_hhf {protect = true} hhf_protect_ss; end; end; structure Basic_Meta_Simplifier: BASIC_RAW_SIMPLIFIER = Raw_Simplifier; open Basic_Meta_Simplifier; diff --git a/src/Pure/simplifier.ML b/src/Pure/simplifier.ML --- a/src/Pure/simplifier.ML +++ b/src/Pure/simplifier.ML @@ -1,434 +1,433 @@ (* Title: Pure/simplifier.ML Author: Tobias Nipkow and Markus Wenzel, TU Muenchen Generic simplifier, suitable for most logics (see also raw_simplifier.ML for the actual meta-level rewriting engine). *) signature BASIC_SIMPLIFIER = sig include BASIC_RAW_SIMPLIFIER val simp_tac: Proof.context -> int -> tactic val asm_simp_tac: Proof.context -> int -> tactic val full_simp_tac: Proof.context -> int -> tactic val asm_lr_simp_tac: Proof.context -> int -> tactic val asm_full_simp_tac: Proof.context -> int -> tactic val safe_simp_tac: Proof.context -> int -> tactic val safe_asm_simp_tac: Proof.context -> int -> tactic val safe_full_simp_tac: Proof.context -> int -> tactic val safe_asm_lr_simp_tac: Proof.context -> int -> tactic val safe_asm_full_simp_tac: Proof.context -> int -> tactic val simplify: Proof.context -> thm -> thm val asm_simplify: Proof.context -> thm -> thm val full_simplify: Proof.context -> thm -> thm val asm_lr_simplify: Proof.context -> thm -> thm val asm_full_simplify: Proof.context -> thm -> thm end; signature SIMPLIFIER = sig include BASIC_SIMPLIFIER val map_ss: (Proof.context -> Proof.context) -> Context.generic -> Context.generic val attrib: (thm -> Proof.context -> Proof.context) -> attribute val simp_add: attribute val simp_del: attribute val simp_flip: attribute val cong_add: attribute val cong_del: attribute - val check_simproc: Proof.context -> xstring * Position.T -> string + val check_simproc: Proof.context -> xstring * Position.T -> string * simproc val the_simproc: Proof.context -> string -> simproc type 'a simproc_spec = {lhss: 'a list, proc: morphism -> Proof.context -> cterm -> thm option} val make_simproc: Proof.context -> string -> term simproc_spec -> simproc val define_simproc: binding -> term simproc_spec -> local_theory -> local_theory val define_simproc_cmd: binding -> string simproc_spec -> local_theory -> local_theory val pretty_simpset: bool -> Proof.context -> Pretty.T val default_mk_sym: Proof.context -> thm -> thm option val prems_of: Proof.context -> thm list val add_simp: thm -> Proof.context -> Proof.context val del_simp: thm -> Proof.context -> Proof.context val init_simpset: thm list -> Proof.context -> Proof.context val add_eqcong: thm -> Proof.context -> Proof.context val del_eqcong: thm -> Proof.context -> Proof.context val add_cong: thm -> Proof.context -> Proof.context val del_cong: thm -> Proof.context -> Proof.context val add_prems: thm list -> Proof.context -> Proof.context val mksimps: Proof.context -> thm -> thm list val set_mksimps: (Proof.context -> thm -> thm list) -> Proof.context -> Proof.context val set_mkcong: (Proof.context -> thm -> thm) -> Proof.context -> Proof.context val set_mksym: (Proof.context -> thm -> thm option) -> Proof.context -> Proof.context val set_mkeqTrue: (Proof.context -> thm -> thm option) -> Proof.context -> Proof.context val set_term_ord: term ord -> Proof.context -> Proof.context val set_subgoaler: (Proof.context -> int -> tactic) -> Proof.context -> Proof.context type trace_ops val set_trace_ops: trace_ops -> theory -> theory val rewrite: Proof.context -> conv val asm_rewrite: Proof.context -> conv val full_rewrite: Proof.context -> conv val asm_lr_rewrite: Proof.context -> conv val asm_full_rewrite: Proof.context -> conv val cong_modifiers: Method.modifier parser list val simp_modifiers': Method.modifier parser list val simp_modifiers: Method.modifier parser list val method_setup: Method.modifier parser list -> theory -> theory val unsafe_solver_tac: Proof.context -> int -> tactic val unsafe_solver: solver val safe_solver_tac: Proof.context -> int -> tactic val safe_solver: solver end; structure Simplifier: SIMPLIFIER = struct open Raw_Simplifier; (** declarations **) (* attributes *) fun attrib f = Thm.declaration_attribute (map_ss o f); val simp_add = attrib add_simp; val simp_del = attrib del_simp; val simp_flip = attrib flip_simp; val cong_add = attrib add_cong; val cong_del = attrib del_cong; (** named simprocs **) structure Simprocs = Generic_Data ( type T = simproc Name_Space.table; val empty : T = Name_Space.empty_table "simproc"; fun merge data : T = Name_Space.merge_tables data; ); (* get simprocs *) val get_simprocs = Simprocs.get o Context.Proof; -fun check_simproc ctxt = Name_Space.check (Context.Proof ctxt) (get_simprocs ctxt) #> #1; val the_simproc = Name_Space.get o get_simprocs; +fun check_simproc ctxt = Name_Space.check (Context.Proof ctxt) (get_simprocs ctxt); val _ = Theory.setup (ML_Antiquotation.value_embedded \<^binding>\simproc\ - (Args.context -- Scan.lift Parse.embedded_position - >> (fn (ctxt, name) => - "Simplifier.the_simproc ML_context " ^ ML_Syntax.print_string (check_simproc ctxt name)))); + (Args.context -- Scan.lift Parse.embedded_position >> (fn (ctxt, name) => + "Simplifier.the_simproc ML_context " ^ ML_Syntax.print_string (#1 (check_simproc ctxt name))))); (* define simprocs *) type 'a simproc_spec = {lhss: 'a list, proc: morphism -> Proof.context -> cterm -> thm option}; fun make_simproc ctxt name {lhss, proc} = let val ctxt' = fold Proof_Context.augment lhss ctxt; val lhss' = Variable.export_terms ctxt' ctxt lhss; in cert_simproc (Proof_Context.theory_of ctxt) name {lhss = lhss', proc = Morphism.entity proc} end; local fun def_simproc prep b {lhss, proc} lthy = let - val simproc = + val simproc0 = make_simproc lthy (Local_Theory.full_name lthy b) {lhss = prep lthy lhss, proc = proc}; in lthy |> Local_Theory.declaration {syntax = false, pervasive = false, pos = Binding.pos_of b} (fn phi => fn context => let val b' = Morphism.binding phi b; - val simproc' = transform_simproc phi simproc; + val simproc' = simproc0 |> transform_simproc phi |> trim_context_simproc; in context |> Simprocs.map (#2 o Name_Space.define context true (b', simproc')) |> map_ss (fn ctxt => ctxt addsimprocs [simproc']) end) end; in val define_simproc = def_simproc Syntax.check_terms; val define_simproc_cmd = def_simproc Syntax.read_terms; end; (** congruence rule to protect foundational terms of local definitions **) local fun add_foundation_cong (binding, (const, target_params)) gthy = if null target_params then gthy else let val thy = Context.theory_of gthy; val cong = list_comb (const, target_params) |> Logic.varify_global |> Thm.global_cterm_of thy |> Thm.reflexive |> Thm.close_derivation \<^here>; val cong_binding = Binding.qualify_name true binding "cong"; in gthy |> Attrib.generic_notes Thm.theoremK [((cong_binding, []), [([cong], [])])] |> #2 end; val _ = Theory.setup (Generic_Target.add_foundation_interpretation add_foundation_cong); in end; (** pretty_simpset **) fun pretty_simpset verbose ctxt = let val pretty_term = Syntax.pretty_term ctxt; val pretty_thm = Thm.pretty_thm ctxt; val pretty_thm_item = Thm.pretty_thm_item ctxt; fun pretty_simproc (name, lhss) = Pretty.block (Pretty.mark_str name :: Pretty.str ":" :: Pretty.fbrk :: Pretty.fbreaks (map (Pretty.item o single o pretty_term) lhss)); fun pretty_cong_name (const, name) = pretty_term ((if const then Const else Free) (name, dummyT)); fun pretty_cong (name, thm) = Pretty.block [pretty_cong_name name, Pretty.str ":", Pretty.brk 1, pretty_thm thm]; val {simps, procs, congs, loopers, unsafe_solvers, safe_solvers, ...} = dest_ss (simpset_of ctxt); val simprocs = Name_Space.markup_entries verbose ctxt (Name_Space.space_of_table (get_simprocs ctxt)) procs; in [Pretty.big_list "simplification rules:" (map (pretty_thm_item o #2) simps), Pretty.big_list "simplification procedures:" (map pretty_simproc simprocs), Pretty.big_list "congruences:" (map pretty_cong congs), Pretty.strs ("loopers:" :: map quote loopers), Pretty.strs ("unsafe solvers:" :: map quote unsafe_solvers), Pretty.strs ("safe solvers:" :: map quote safe_solvers)] |> Pretty.chunks end; (** simplification tactics and rules **) fun solve_all_tac solvers ctxt = let val subgoal_tac = Raw_Simplifier.subgoal_tac (Raw_Simplifier.set_solvers solvers ctxt); val solve_tac = subgoal_tac THEN_ALL_NEW (K no_tac); in DEPTH_SOLVE (solve_tac 1) end; (*NOTE: may instantiate unknowns that appear also in other subgoals*) fun generic_simp_tac safe mode ctxt = let val loop_tac = Raw_Simplifier.loop_tac ctxt; val (unsafe_solvers, solvers) = Raw_Simplifier.solvers ctxt; val solve_tac = FIRST' (map (Raw_Simplifier.solver ctxt) (rev (if safe then solvers else unsafe_solvers))); fun simp_loop_tac i = Raw_Simplifier.generic_rewrite_goal_tac mode (solve_all_tac unsafe_solvers) ctxt i THEN (solve_tac i ORELSE TRY ((loop_tac THEN_ALL_NEW simp_loop_tac) i)); in PREFER_GOAL (simp_loop_tac 1) end; local fun simp rew mode ctxt thm = let val (unsafe_solvers, _) = Raw_Simplifier.solvers ctxt; val tacf = solve_all_tac (rev unsafe_solvers); fun prover s th = Option.map #1 (Seq.pull (tacf s th)); in rew mode prover ctxt thm end; in val simp_thm = simp Raw_Simplifier.rewrite_thm; val simp_cterm = simp Raw_Simplifier.rewrite_cterm; end; (* tactics *) val simp_tac = generic_simp_tac false (false, false, false); val asm_simp_tac = generic_simp_tac false (false, true, false); val full_simp_tac = generic_simp_tac false (true, false, false); val asm_lr_simp_tac = generic_simp_tac false (true, true, false); val asm_full_simp_tac = generic_simp_tac false (true, true, true); (*not totally safe: may instantiate unknowns that appear also in other subgoals*) val safe_simp_tac = generic_simp_tac true (false, false, false); val safe_asm_simp_tac = generic_simp_tac true (false, true, false); val safe_full_simp_tac = generic_simp_tac true (true, false, false); val safe_asm_lr_simp_tac = generic_simp_tac true (true, true, false); val safe_asm_full_simp_tac = generic_simp_tac true (true, true, true); (* conversions *) val simplify = simp_thm (false, false, false); val asm_simplify = simp_thm (false, true, false); val full_simplify = simp_thm (true, false, false); val asm_lr_simplify = simp_thm (true, true, false); val asm_full_simplify = simp_thm (true, true, true); val rewrite = simp_cterm (false, false, false); val asm_rewrite = simp_cterm (false, true, false); val full_rewrite = simp_cterm (true, false, false); val asm_lr_rewrite = simp_cterm (true, true, false); val asm_full_rewrite = simp_cterm (true, true, true); (** concrete syntax of attributes **) (* add / del *) val simpN = "simp"; val flipN = "flip" val congN = "cong"; val onlyN = "only"; val no_asmN = "no_asm"; val no_asm_useN = "no_asm_use"; val no_asm_simpN = "no_asm_simp"; val asm_lrN = "asm_lr"; (* simprocs *) local val add_del = (Args.del -- Args.colon >> K (op delsimprocs) || Scan.option (Args.add -- Args.colon) >> K (op addsimprocs)) >> (fn f => fn simproc => Morphism.entity (fn phi => Thm.declaration_attribute (K (Raw_Simplifier.map_ss (fn ctxt => f (ctxt, [transform_simproc phi simproc])))))); in val simproc_att = (Args.context -- Scan.lift add_del) :|-- (fn (ctxt, decl) => - Scan.repeat1 (Scan.lift (Args.named_attribute (decl o the_simproc ctxt o check_simproc ctxt)))) + Scan.repeat1 (Scan.lift (Args.named_attribute (decl o #2 o check_simproc ctxt)))) >> (fn atts => Thm.declaration_attribute (fn th => fold (fn att => Thm.attribute_declaration (Morphism.form att) th) atts)); end; (* conversions *) local fun conv_mode x = ((Args.parens (Args.$$$ no_asmN) >> K simplify || Args.parens (Args.$$$ no_asm_simpN) >> K asm_simplify || Args.parens (Args.$$$ no_asm_useN) >> K full_simplify || Scan.succeed asm_full_simplify) |> Scan.lift) x; in val simplified = conv_mode -- Attrib.thms >> (fn (f, ths) => Thm.rule_attribute ths (fn context => f ((if null ths then I else Raw_Simplifier.clear_simpset) (Context.proof_of context) addsimps ths))); end; (* setup attributes *) val _ = Theory.setup (Attrib.setup \<^binding>\simp\ (Attrib.add_del simp_add simp_del) "declaration of Simplifier rewrite rule" #> Attrib.setup \<^binding>\cong\ (Attrib.add_del cong_add cong_del) "declaration of Simplifier congruence rule" #> Attrib.setup \<^binding>\simproc\ simproc_att "declaration of simplification procedures" #> Attrib.setup \<^binding>\simplified\ simplified "simplified rule"); (** method syntax **) val cong_modifiers = [Args.$$$ congN -- Args.colon >> K (Method.modifier cong_add \<^here>), Args.$$$ congN -- Args.add -- Args.colon >> K (Method.modifier cong_add \<^here>), Args.$$$ congN -- Args.del -- Args.colon >> K (Method.modifier cong_del \<^here>)]; val simp_modifiers = [Args.$$$ simpN -- Args.colon >> K (Method.modifier simp_add \<^here>), Args.$$$ simpN -- Args.add -- Args.colon >> K (Method.modifier simp_add \<^here>), Args.$$$ simpN -- Args.del -- Args.colon >> K (Method.modifier simp_del \<^here>), Args.$$$ simpN -- Args.$$$ flipN -- Args.colon >> K (Method.modifier simp_flip \<^here>), Args.$$$ simpN -- Args.$$$ onlyN -- Args.colon >> K {init = Raw_Simplifier.clear_simpset, attribute = simp_add, pos = \<^here>}] @ cong_modifiers; val simp_modifiers' = [Args.add -- Args.colon >> K (Method.modifier simp_add \<^here>), Args.del -- Args.colon >> K (Method.modifier simp_del \<^here>), Args.$$$ flipN -- Args.colon >> K (Method.modifier simp_flip \<^here>), Args.$$$ onlyN -- Args.colon >> K {init = Raw_Simplifier.clear_simpset, attribute = simp_add, pos = \<^here>}] @ cong_modifiers; val simp_options = (Args.parens (Args.$$$ no_asmN) >> K simp_tac || Args.parens (Args.$$$ no_asm_simpN) >> K asm_simp_tac || Args.parens (Args.$$$ no_asm_useN) >> K full_simp_tac || Args.parens (Args.$$$ asm_lrN) >> K asm_lr_simp_tac || Scan.succeed asm_full_simp_tac); fun simp_method more_mods meth = Scan.lift simp_options --| Method.sections (more_mods @ simp_modifiers') >> (fn tac => fn ctxt => METHOD (fn facts => meth ctxt tac facts)); (** setup **) fun method_setup more_mods = Method.setup \<^binding>\simp\ (simp_method more_mods (fn ctxt => fn tac => fn facts => HEADGOAL (Method.insert_tac ctxt facts THEN' (CHANGED_PROP oo tac) ctxt))) "simplification" #> Method.setup \<^binding>\simp_all\ (simp_method more_mods (fn ctxt => fn tac => fn facts => ALLGOALS (Method.insert_tac ctxt facts) THEN (CHANGED_PROP o PARALLEL_ALLGOALS o tac) ctxt)) "simplification (all goals)"; fun unsafe_solver_tac ctxt = FIRST' [resolve_tac ctxt (Drule.reflexive_thm :: Raw_Simplifier.prems_of ctxt), assume_tac ctxt]; val unsafe_solver = mk_solver "Pure unsafe" unsafe_solver_tac; (*no premature instantiation of variables during simplification*) fun safe_solver_tac ctxt = FIRST' [match_tac ctxt (Drule.reflexive_thm :: Raw_Simplifier.prems_of ctxt), eq_assume_tac]; val safe_solver = mk_solver "Pure safe" safe_solver_tac; val _ = Theory.setup (method_setup [] #> Context.theory_map (map_ss (fn ctxt => empty_simpset ctxt setSSolver safe_solver setSolver unsafe_solver |> set_subgoaler asm_simp_tac))); end; structure Basic_Simplifier: BASIC_SIMPLIFIER = Simplifier; open Basic_Simplifier;