diff --git a/src/HOL/TPTP/mash_eval.ML b/src/HOL/TPTP/mash_eval.ML --- a/src/HOL/TPTP/mash_eval.ML +++ b/src/HOL/TPTP/mash_eval.ML @@ -1,170 +1,170 @@ (* Title: HOL/TPTP/mash_eval.ML Author: Jasmin Blanchette, TU Muenchen Copyright 2012 Evaluate proof suggestions from MaSh (Machine-learning for Sledgehammer). *) signature MASH_EVAL = sig type params = Sledgehammer_Prover.params val evaluate_mash_suggestions : Proof.context -> params -> int * int option -> string option -> string list -> string -> unit end; structure MaSh_Eval : MASH_EVAL = struct open Sledgehammer_Util open Sledgehammer_Fact open Sledgehammer_MePo open Sledgehammer_MaSh open Sledgehammer_Prover open Sledgehammer_Prover_ATP open Sledgehammer_Commands open MaSh_Export val prefix = Library.prefix fun evaluate_mash_suggestions ctxt params range prob_dir_name file_names report_file_name = let val thy = Proof_Context.theory_of ctxt val zeros = [0, 0, 0, 0, 0, 0] val report_path = report_file_name |> Path.explode val _ = File.write report_path "" fun print s = File.append report_path (s ^ "\n") val {provers, max_facts, slice, type_enc, lam_trans, timeout, induction_rules, ...} = default_params thy [] val prover = hd provers val max_suggs = generous_max_suggestions (the max_facts) val inst_inducts = induction_rules = SOME Instantiate val method_of_file_name = perhaps (try (unsuffix "_suggestions")) o List.last o space_explode "/" val methods = "isar" :: map method_of_file_name file_names val lines_of = Path.explode #> try File.read_lines #> these val liness0 = map lines_of file_names val num_lines = fold (Integer.max o length) liness0 0 fun pad lines = lines @ replicate (num_lines - length lines) "" val liness' = Ctr_Sugar_Util.transpose (map pad liness0) val css = clasimpset_rule_table_of ctxt val facts = all_facts ctxt true Keyword.empty_keywords [] [] css val name_tabs = build_name_tables nickname_of_thm facts fun with_index facts s = (find_index (curry (op =) s) facts + 1, s) fun index_str (j, s) = s ^ 
"@" ^ string_of_int j val str_of_method = enclose " " ": " fun str_of_result method facts ({outcome, run_time, used_facts, ...} : prover_result) = let val facts = facts |> map (fst o fst) in str_of_method method ^ (if is_none outcome then "Success (" ^ ATP_Util.string_of_time run_time ^ "): " ^ (used_facts |> map (with_index facts o fst) |> sort (int_ord o apply2 fst) |> map index_str |> space_implode " ") ^ (if length facts < the max_facts then " (of " ^ string_of_int (length facts) ^ ")" else "") else "Failure: " ^ (facts |> take (the max_facts) |> tag_list 1 |> map index_str |> space_implode " ")) end fun solve_goal (j, lines) = if in_range range j andalso exists (curry (op <>) "") lines then let val get_suggs = extract_suggestions ##> (take max_suggs #> map fst) val (names, suggss0) = split_list (map get_suggs lines) val name = (case names |> filter (curry (op <>) "") |> distinct (op =) of [name] => name | names => error ("Input files out of sync: facts " ^ commas (map quote names))) val th = case find_first (fn (_, th) => nickname_of_thm th = name) facts of SOME (_, th) => th | NONE => error ("No fact called \"" ^ name) val goal = goal_of_thm (Proof_Context.theory_of ctxt) th val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal 1 ctxt val isar_deps = these (isar_dependencies_of name_tabs th) val suggss = isar_deps :: suggss0 val facts = facts |> filter (fn (_, th') => thm_less (th', th)) (* adapted from "mirabelle_sledgehammer.ML" *) fun set_file_name method (SOME dir) = let val prob_prefix = "goal_" ^ string_of_int j ^ "__" ^ encode_str name ^ "__" ^ method in Config.put atp_dest_dir dir #> Config.put atp_problem_prefix (prob_prefix ^ "__") #> Config.put SMT_Config.debug_files (dir ^ "/" ^ prob_prefix) end | set_file_name _ NONE = I fun prove method suggs = if null facts then (str_of_method method ^ "Skipped", 0) else let fun nickify ((_, stature), th) = ((K (encode_str (nickname_of_thm th)), stature), th) val facts = suggs |> find_suggested_facts ctxt 
facts - |> map (fact_of_raw_fact #> nickify) + |> map (fact_of_lazy_fact #> nickify) |> inst_inducts ? instantiate_inducts ctxt hyp_ts concl_t |> take (the max_facts) - |> map fact_of_raw_fact + |> map fact_of_lazy_fact val ctxt = ctxt |> set_file_name method prob_dir_name val res as {outcome, ...} = run_prover_for_mash ctxt params prover name facts goal val ok = if is_none outcome then 1 else 0 in (str_of_result method facts res, ok) end val ress = map2 prove methods suggss in "Goal " ^ string_of_int j ^ ": " ^ name :: map fst ress |> cat_lines |> print; map snd ress end else zeros val options = ["prover = " ^ prover, "max_facts = " ^ string_of_int (the max_facts), "slice" |> not slice ? prefix "dont_", "type_enc = " ^ the_default "smart" type_enc, "lam_trans = " ^ the_default "smart" lam_trans, "timeout = " ^ ATP_Util.string_of_time timeout, "instantiate_inducts" |> not inst_inducts ? prefix "dont_"] val _ = print " * * *"; val _ = print ("Options: " ^ commas options); val oks = Par_List.map solve_goal (tag_list 1 liness') val n = length oks fun total_of method ok = str_of_method method ^ string_of_int ok ^ " (" ^ Real.fmt (StringCvt.FIX (SOME 1)) (100.0 * Real.fromInt ok / Real.fromInt (Int.max (1, n))) ^ "%)" val oks' = if n = 0 then zeros else map Integer.sum (map_transpose I oks) in "Successes (of " ^ string_of_int n ^ " goals)" :: map2 total_of methods oks' |> cat_lines |> print end end; diff --git a/src/HOL/Tools/Sledgehammer/sledgehammer_fact.ML b/src/HOL/Tools/Sledgehammer/sledgehammer_fact.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact.ML +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact.ML @@ -1,559 +1,558 @@ (* Title: HOL/Tools/Sledgehammer/sledgehammer_fact.ML Author: Jia Meng, Cambridge University Computer Laboratory and NICTA Author: Jasmin Blanchette, TU Muenchen Sledgehammer fact handling. 
*) signature SLEDGEHAMMER_FACT = sig type status = ATP_Problem_Generate.status type stature = ATP_Problem_Generate.stature - type raw_fact = ((unit -> string) * stature) * thm (* TODO: rename to lazy_fact *) + type lazy_fact = ((unit -> string) * stature) * thm type fact = (string * stature) * thm type fact_override = {add : (Facts.ref * Token.src list) list, del : (Facts.ref * Token.src list) list, only : bool} val no_fact_override : fact_override val fact_of_ref : Proof.context -> Keyword.keywords -> thm list -> status Termtab.table -> Facts.ref * Token.src list -> ((string * stature) * thm) list val cartouche_thm : Proof.context -> thm -> string val is_blacklisted_or_something : string -> bool val clasimpset_rule_table_of : Proof.context -> status Termtab.table val build_name_tables : (thm -> string) -> ('a * thm) list -> string Symtab.table * string Symtab.table val fact_distinct : (term * term -> bool) -> ('a * thm) list -> ('a * thm) list val instantiate_inducts : Proof.context -> term list -> term -> (((unit -> string) * 'a) * thm) list -> (((unit -> string) * 'a) * thm) list - val fact_of_raw_fact : raw_fact -> fact + val fact_of_lazy_fact : lazy_fact -> fact val is_useful_unnamed_local_fact : Proof.context -> thm -> bool val all_facts : Proof.context -> bool -> Keyword.keywords -> thm list -> thm list -> - status Termtab.table -> raw_fact list + status Termtab.table -> lazy_fact list val nearly_all_facts : Proof.context -> bool -> fact_override -> Keyword.keywords -> - status Termtab.table -> thm list -> term list -> term -> raw_fact list - val drop_duplicate_facts : raw_fact list -> raw_fact list - + status Termtab.table -> thm list -> term list -> term -> lazy_fact list + val drop_duplicate_facts : lazy_fact list -> lazy_fact list end; structure Sledgehammer_Fact : SLEDGEHAMMER_FACT = struct open ATP_Util open ATP_Problem_Generate open Sledgehammer_Util -type raw_fact = ((unit -> string) * stature) * thm +type lazy_fact = ((unit -> string) * stature) * 
thm type fact = (string * stature) * thm type fact_override = {add : (Facts.ref * Token.src list) list, del : (Facts.ref * Token.src list) list, only : bool} val local_thisN = Long_Name.localN ^ Long_Name.separator ^ Auto_Bind.thisN (* gracefully handle huge background theories *) val max_facts_for_duplicates = 50000 val max_facts_for_complex_check = 25000 val max_simps_for_clasimpset = 10000 val no_fact_override = {add = [], del = [], only = false} fun needs_quoting keywords s = Keyword.is_literal keywords s orelse exists (not o Symbol_Pos.is_identifier) (Long_Name.explode s) fun make_name keywords multi j name = (name |> needs_quoting keywords name ? quote) ^ (if multi then "(" ^ string_of_int j ^ ")" else "") fun explode_interval _ (Facts.FromTo (i, j)) = i upto j | explode_interval max (Facts.From i) = i upto i + max - 1 | explode_interval _ (Facts.Single i) = [i] fun is_rec_eq lhs = Term.exists_subterm (curry (op =) (head_of lhs)) fun is_rec_def (\<^const>\Trueprop\ $ t) = is_rec_def t | is_rec_def (\<^const>\Pure.imp\ $ _ $ t2) = is_rec_def t2 | is_rec_def (Const (\<^const_name>\Pure.eq\, _) $ t1 $ t2) = is_rec_eq t1 t2 | is_rec_def (Const (\<^const_name>\HOL.eq\, _) $ t1 $ t2) = is_rec_eq t1 t2 | is_rec_def _ = false fun is_assum assms th = exists (fn ct => Thm.prop_of th aconv Thm.term_of ct) assms fun is_chained chained = member Thm.eq_thm_prop chained fun scope_of_thm global assms chained th = if is_chained chained th then Chained else if global then Global else if is_assum assms th then Assum else Local val may_be_induction = exists_subterm (fn Var (_, Type (\<^type_name>\fun\, [_, T])) => body_type T = \<^typ>\bool\ | _ => false) (* TODO: get rid of *) fun normalize_vars t = let fun normT (Type (s, Ts)) = fold_map normT Ts #>> curry Type s | normT (TVar (z as (_, S))) = (fn ((knownT, nT), accum) => (case find_index (equal z) knownT of ~1 => (TVar ((Name.uu, nT), S), ((z :: knownT, nT + 1), accum)) | j => (TVar ((Name.uu, nT - j - 1), S), ((knownT, nT), 
accum)))) | normT (T as TFree _) = pair T fun norm (t $ u) = norm t ##>> norm u #>> op $ | norm (Const (s, T)) = normT T #>> curry Const s | norm (Var (z as (_, T))) = normT T #> (fn (T, (accumT, (known, n))) => (case find_index (equal z) known of ~1 => (Var ((Name.uu, n), T), (accumT, (z :: known, n + 1))) | j => (Var ((Name.uu, n - j - 1), T), (accumT, (known, n))))) | norm (Abs (_, T, t)) = norm t ##>> normT T #>> (fn (t, T) => Abs (Name.uu, T, t)) | norm (Bound j) = pair (Bound j) | norm (Free (s, T)) = normT T #>> curry Free s in fst (norm t (([], 0), ([], 0))) end fun status_of_thm css name th = if Termtab.is_empty css then General else let val t = Thm.prop_of th in (* FIXME: use structured name *) if String.isSubstring ".induct" name andalso may_be_induction t then Induction else let val t = normalize_vars t in (case Termtab.lookup css t of SOME status => status | NONE => let val concl = Logic.strip_imp_concl t in (case try (HOLogic.dest_eq o HOLogic.dest_Trueprop) concl of SOME lrhss => let val prems = Logic.strip_imp_prems t val t' = Logic.list_implies (prems, Logic.mk_equals lrhss) in Termtab.lookup css t' |> the_default General end | NONE => General) end) end end fun stature_of_thm global assms chained css name th = (scope_of_thm global assms chained th, status_of_thm css name th) fun fact_of_ref ctxt keywords chained css (xthm as (xref, args)) = let val ths = Attrib.eval_thms ctxt [xthm] val bracket = implode (map (enclose "[" "]" o Pretty.unformatted_string_of o Token.pretty_src ctxt) args) fun nth_name j = (case xref of Facts.Fact s => cartouche (simplify_spaces (YXML.content_of s)) ^ bracket | Facts.Named (("", _), _) => "[" ^ bracket ^ "]" | Facts.Named ((name, _), NONE) => make_name keywords (length ths > 1) (j + 1) name ^ bracket | Facts.Named ((name, _), SOME intervals) => make_name keywords true (nth (maps (explode_interval (length ths)) intervals) j) name ^ bracket) fun add_nth th (j, rest) = let val name = nth_name j in (j + 1, ((name, 
stature_of_thm false [] chained css name th), th) :: rest) end in (0, []) |> fold add_nth ths |> snd end (* Reject theorems with names like "List.filter.filter_list_def" or "Accessible_Part.acc.defs", as these are definitions arising from packages. *) fun is_package_def s = let val ss = Long_Name.explode s in length ss > 2 andalso not (hd ss = "local") andalso exists (fn suf => String.isSuffix suf s) ["_case_def", "_rec_def", "_size_def", "_size_overloaded_def"] end (* FIXME: put other record thms here, or declare as "no_atp" *) val multi_base_blacklist = ["defs", "select_defs", "update_defs", "split", "splits", "split_asm", "ext_cases", "eq.simps", "eq.refl", "nchotomy", "case_cong", "case_cong_weak", "nat_of_char_simps", "nibble.simps", "nibble.distinct"] |> map (prefix Long_Name.separator) (* The maximum apply depth of any "metis" call in "Metis_Examples" (back in 2007) was 11. *) val max_apply_depth = 18 fun apply_depth (f $ t) = Int.max (apply_depth f, apply_depth t + 1) | apply_depth (Abs (_, _, t)) = apply_depth t | apply_depth _ = 0 fun is_too_complex t = apply_depth t > max_apply_depth (* FIXME: Ad hoc list *) val technical_prefixes = ["ATP", "Code_Evaluation", "Datatype", "Enum", "Lazy_Sequence", "Limited_Sequence", "Meson", "Metis", "Nitpick", "Quickcheck_Random", "Quickcheck_Exhaustive", "Quickcheck_Narrowing", "Random_Sequence", "Sledgehammer", "SMT"] |> map (suffix Long_Name.separator) fun is_technical_const s = exists (fn pref => String.isPrefix pref s) technical_prefixes (* FIXME: make more reliable *) val sep_class_sep = Long_Name.separator ^ "class" ^ Long_Name.separator fun is_low_level_class_const s = s = \<^const_name>\equal_class.equal\ orelse String.isSubstring sep_class_sep s val sep_that = Long_Name.separator ^ Auto_Bind.thatN val skolem_thesis = Name.skolem Auto_Bind.thesisN fun is_that_fact th = exists_subterm (fn Free (s, _) => s = skolem_thesis | _ => false) (Thm.prop_of th) andalso String.isSuffix sep_that (Thm.get_name_hint th) 
datatype interest = Deal_Breaker | Interesting | Boring fun combine_interests Deal_Breaker _ = Deal_Breaker | combine_interests _ Deal_Breaker = Deal_Breaker | combine_interests Interesting _ = Interesting | combine_interests _ Interesting = Interesting | combine_interests Boring Boring = Boring val type_has_top_sort = exists_subtype (fn TFree (_, []) => true | TVar (_, []) => true | _ => false) fun is_likely_tautology_too_meta_or_too_technical th = let fun is_interesting_subterm (Const (s, _)) = not (member (op =) atp_widely_irrelevant_consts s) | is_interesting_subterm (Free _) = true | is_interesting_subterm _ = false fun interest_of_bool t = if exists_Const ((is_technical_const o fst) orf (is_low_level_class_const o fst) orf type_has_top_sort o snd) t then Deal_Breaker else if exists_type (exists_subtype (curry (op =) \<^typ>\prop\)) t orelse not (exists_subterm is_interesting_subterm t) then Boring else Interesting fun interest_of_prop _ (\<^const>\Trueprop\ $ t) = interest_of_bool t | interest_of_prop Ts (\<^const>\Pure.imp\ $ t $ u) = combine_interests (interest_of_prop Ts t) (interest_of_prop Ts u) | interest_of_prop Ts (Const (\<^const_name>\Pure.all\, _) $ Abs (_, T, t)) = if type_has_top_sort T then Deal_Breaker else interest_of_prop (T :: Ts) t | interest_of_prop Ts ((t as Const (\<^const_name>\Pure.all\, _)) $ u) = interest_of_prop Ts (t $ eta_expand Ts u 1) | interest_of_prop _ (Const (\<^const_name>\Pure.eq\, _) $ t $ u) = combine_interests (interest_of_bool t) (interest_of_bool u) | interest_of_prop _ _ = Deal_Breaker val t = Thm.prop_of th in (interest_of_prop [] t <> Interesting andalso not (Thm.eq_thm_prop (@{thm ext}, th))) orelse is_that_fact th end val is_blacklisted_or_something = let val blist = multi_base_blacklist in fn name => is_package_def name orelse exists (fn s => String.isSuffix s name) blist end (* This is a terrible hack. 
Free variables are sometimes coded as "M__" when they are displayed as "M" and we want to avoid clashes with these. But sometimes it's even worse: "Ma__" encodes "M". So we simply reserve all prefixes of all free variables. In the worse case scenario, where the fact won't be resolved correctly, the user can fix it manually, e.g., by giving a name to the offending fact. *) fun all_prefixes_of s = map (fn i => String.extract (s, 0, SOME i)) (1 upto size s - 1) fun close_form t = (t, [] |> Term.add_free_names t |> maps all_prefixes_of) |> fold (fn ((s, i), T) => fn (t', taken) => let val s' = singleton (Name.variant_list taken) s in ((if fastype_of t' = HOLogic.boolT then HOLogic.all_const else Logic.all_const) T $ Abs (s', T, abstract_over (Var ((s, i), T), t')), s' :: taken) end) (Term.add_vars t [] |> sort_by (fst o fst)) |> fst fun cartouche_term ctxt = close_form #> hackish_string_of_term ctxt #> cartouche fun cartouche_thm ctxt = cartouche_term ctxt o Thm.prop_of (* TODO: rewrite to use nets and/or to reuse existing data structures *) fun clasimpset_rule_table_of ctxt = let val simps = ctxt |> simpset_of |> dest_ss |> #simps in if length simps >= max_simps_for_clasimpset then Termtab.empty else let fun add stature th = Termtab.update (normalize_vars (Thm.prop_of th), stature) val {safeIs, (* safeEs, *) unsafeIs, (* unsafeEs, *) ...} = ctxt |> claset_of |> Classical.rep_cs val intros = map #1 (Item_Net.content safeIs @ Item_Net.content unsafeIs) (* Add once it is used: val elims = Item_Net.content safeEs @ Item_Net.content unsafeEs |> map Classical.classical_rule *) val specs = Spec_Rules.get ctxt val (rec_defs, nonrec_defs) = specs |> filter (Spec_Rules.is_equational o #rough_classification) |> maps #rules |> List.partition (is_rec_def o Thm.prop_of) val spec_intros = specs |> filter (Spec_Rules.is_relational o #rough_classification) |> maps #rules in Termtab.empty |> fold (add Simp o snd) simps |> fold (add Rec_Def) rec_defs |> fold (add Non_Rec_Def) 
nonrec_defs (* Add once it is used: |> fold (add Elim) elims *) |> fold (add Intro) intros |> fold (add Inductive) spec_intros end end fun normalize_eq (\<^const>\Trueprop\ $ (t as (t0 as Const (\<^const_name>\HOL.eq\, _)) $ t1 $ t2)) = if is_less_equal (Term_Ord.fast_term_ord (t1, t2)) then t else t0 $ t2 $ t1 | normalize_eq (\<^const>\Trueprop\ $ (t as \<^const>\Not\ $ ((t0 as Const (\<^const_name>\HOL.eq\, _)) $ t1 $ t2))) = if is_less_equal (Term_Ord.fast_term_ord (t1, t2)) then t else HOLogic.mk_not (t0 $ t2 $ t1) | normalize_eq (Const (\<^const_name>\Pure.eq\, Type (_, [T, _])) $ t1 $ t2) = (if is_less_equal (Term_Ord.fast_term_ord (t1, t2)) then (t1, t2) else (t2, t1)) |> (fn (t1, t2) => HOLogic.eq_const T $ t1 $ t2) | normalize_eq t = t fun if_thm_before th th' = if Context.subthy_id (apply2 Thm.theory_id (th, th')) then th else th' (* Hack: Conflate the facts about a class as seen from the outside with the corresponding low-level facts, so that MaSh can learn from the low-level proofs. 
*) fun un_class_ify s = (case first_field "_class" s of SOME (pref, suf) => [s, pref ^ suf] | NONE => [s]) fun build_name_tables name_of facts = let fun cons_thm (_, th) = Termtab.cons_list (normalize_vars (normalize_eq (Thm.prop_of th)), th) fun add_plain canon alias = Symtab.update (Thm.get_name_hint alias, name_of (if_thm_before canon alias)) fun add_plains (_, aliases as canon :: _) = fold (add_plain canon) aliases fun add_inclass (name, target) = fold (fn s => Symtab.update (s, target)) (un_class_ify name) val prop_tab = fold cons_thm facts Termtab.empty val plain_name_tab = Termtab.fold add_plains prop_tab Symtab.empty val inclass_name_tab = Symtab.fold add_inclass plain_name_tab Symtab.empty in (plain_name_tab, inclass_name_tab) end fun fact_distinct eq facts = fold (fn (i, fact as (_, th)) => Net.insert_term_safe (eq o apply2 (normalize_eq o Thm.prop_of o snd o snd)) (normalize_eq (Thm.prop_of th), (i, fact))) (tag_list 0 facts) Net.empty |> Net.entries |> sort (int_ord o apply2 fst) |> map snd fun struct_induct_rule_on th = (case Logic.strip_horn (Thm.prop_of th) of (prems, \<^const>\Trueprop\ $ ((p as Var ((p_name, 0), _)) $ (a as Var (_, ind_T)))) => if not (is_TVar ind_T) andalso length prems > 1 andalso exists (exists_subterm (curry (op aconv) p)) prems andalso not (exists (exists_subterm (curry (op aconv) a)) prems) then SOME (p_name, ind_T) else NONE | _ => NONE) val instantiate_induct_timeout = seconds 0.01 fun instantiate_induct_rule ctxt concl_prop p_name ((name, stature), th) ind_x = let fun varify_noninducts (t as Free (s, T)) = if (s, T) = ind_x orelse can dest_funT T then t else Var ((s, 0), T) | varify_noninducts t = t val p_inst = concl_prop |> map_aterms varify_noninducts |> close_form |> lambda (Free ind_x) |> hackish_string_of_term ctxt in ((fn () => name () ^ "[where " ^ p_name ^ " = " ^ quote p_inst ^ "]", stature), th |> Rule_Insts.read_instantiate ctxt [(((p_name, 0), Position.none), p_inst)] []) end fun type_match thy (T1, T2) = 
(Sign.typ_match thy (T2, T1) Vartab.empty; true) handle Type.TYPE_MATCH => false fun instantiate_if_induct_rule ctxt stmt stmt_xs (ax as (_, th)) = (case struct_induct_rule_on th of SOME (p_name, ind_T) => let val thy = Proof_Context.theory_of ctxt in stmt_xs |> filter (fn (_, T) => type_match thy (T, ind_T)) |> map_filter (try (Timeout.apply instantiate_induct_timeout (instantiate_induct_rule ctxt stmt p_name ax))) end | NONE => [ax]) fun external_frees t = [] |> Term.add_frees t |> filter_out (Name.is_internal o fst) fun instantiate_inducts ctxt hyp_ts concl_t = let val ind_stmt = (hyp_ts |> filter_out (null o external_frees), concl_t) |> Logic.list_implies |> Object_Logic.atomize_term ctxt val ind_stmt_xs = external_frees ind_stmt in maps (instantiate_if_induct_rule ctxt ind_stmt ind_stmt_xs) end -fun fact_of_raw_fact ((name, stature), th) = ((name (), stature), th) +fun fact_of_lazy_fact ((name, stature), th) = ((name (), stature), th) fun fact_count facts = Facts.fold_static (K (Integer.add 1)) facts 0 fun is_useful_unnamed_local_fact ctxt = let val thy = Proof_Context.theory_of ctxt val global_facts = Global_Theory.facts_of thy val local_facts = Proof_Context.facts_of ctxt val named_locals = Facts.dest_static true [global_facts] local_facts |> maps (map (normalize_eq o Thm.prop_of) o snd) in fn th => not (Thm.has_name_hint th) andalso not (member (op aconv) named_locals (normalize_eq (Thm.prop_of th))) end fun all_facts ctxt generous keywords add_ths chained css = let val thy = Proof_Context.theory_of ctxt val transfer = Global_Theory.transfer_theories thy val global_facts = Global_Theory.facts_of thy val is_too_complex = if generous orelse fact_count global_facts >= max_facts_for_complex_check then K false else is_too_complex val local_facts = Proof_Context.facts_of ctxt val assms = Assumption.all_assms_of ctxt val named_locals = Facts.dest_static true [global_facts] local_facts val unnamed_locals = Facts.props local_facts |> map #1 |> filter 
(is_useful_unnamed_local_fact ctxt) |> map (pair "" o single) val full_space = Name_Space.merge (Facts.space_of global_facts, Facts.space_of local_facts) fun add_facts global foldx facts = foldx (fn (name0, ths) => fn accum => if name0 <> "" andalso (Long_Name.is_hidden (Facts.intern facts name0) orelse ((Facts.is_concealed facts name0 orelse (not generous andalso is_blacklisted_or_something name0)) andalso forall (not o member Thm.eq_thm_prop add_ths) ths)) then accum else let val n = length ths val multi = n > 1 fun check_thms a = (case try (Proof_Context.get_thms ctxt) a of NONE => false | SOME ths' => eq_list Thm.eq_thm_prop (ths, ths')) in snd (fold_rev (fn th0 => fn (j, accum) => let val th = transfer th0 in (j - 1, if not (member Thm.eq_thm_prop add_ths th) andalso (is_likely_tautology_too_meta_or_too_technical th orelse is_too_complex (Thm.prop_of th)) then accum else let fun get_name () = if name0 = "" orelse name0 = local_thisN then cartouche_thm ctxt th else let val short_name = Facts.extern ctxt facts name0 in if check_thms short_name then short_name else let val long_name = Name_Space.extern ctxt full_space name0 in if check_thms long_name then long_name else name0 end end |> make_name keywords multi j val stature = stature_of_thm global assms chained css name0 th val new = ((get_name, stature), th) in (if multi then apsnd else apfst) (cons new) accum end) end) ths (n, accum)) end) in (* The single-theorem names go before the multiple-theorem ones (e.g., "xxx" vs. "xxx(3)"), so that single names are preferred when both are available. 
*) ([], []) |> add_facts false fold local_facts (unnamed_locals @ named_locals) |> add_facts true Facts.fold_static global_facts global_facts |> op @ end fun nearly_all_facts ctxt inst_inducts {add, del, only} keywords css chained hyp_ts concl_t = if only andalso null add then [] else let val chained = chained |> maps (fn th => insert Thm.eq_thm_prop (zero_var_indexes th) [th]) in (if only then maps (map (fn ((name, stature), th) => ((K name, stature), th)) o fact_of_ref ctxt keywords chained css) add else let val (add, del) = apply2 (Attrib.eval_thms ctxt) (add, del) val facts = all_facts ctxt false keywords add chained css |> filter_out ((member Thm.eq_thm_prop del orf (Named_Theorems.member ctxt \<^named_theorems>\no_atp\ andf not o member Thm.eq_thm_prop add)) o snd) in facts end) |> inst_inducts ? instantiate_inducts ctxt hyp_ts concl_t end fun drop_duplicate_facts facts = let val num_facts = length facts in facts |> num_facts <= max_facts_for_duplicates ? fact_distinct (op aconv) end end; diff --git a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML @@ -1,1636 +1,1637 @@ (* Title: HOL/Tools/Sledgehammer/sledgehammer_mash.ML Author: Jasmin Blanchette, TU Muenchen Author: Cezary Kaliszyk, University of Innsbruck Sledgehammer's machine-learning-based relevance filter (MaSh). 
*) signature SLEDGEHAMMER_MASH = sig type stature = ATP_Problem_Generate.stature - type raw_fact = Sledgehammer_Fact.raw_fact + type lazy_fact = Sledgehammer_Fact.lazy_fact type fact = Sledgehammer_Fact.fact type fact_override = Sledgehammer_Fact.fact_override type params = Sledgehammer_Prover.params type prover_result = Sledgehammer_Prover.prover_result val trace : bool Config.T val duplicates : bool Config.T val MePoN : string val MaShN : string val MeShN : string val mepoN : string val mashN : string val meshN : string val unlearnN : string val learn_isarN : string val learn_proverN : string val relearn_isarN : string val relearn_proverN : string val fact_filters : string list val encode_str : string -> string val encode_strs : string list -> string val decode_str : string -> string val decode_strs : string -> string list datatype mash_algorithm = MaSh_NB | MaSh_kNN | MaSh_NB_kNN | MaSh_NB_Ext | MaSh_kNN_Ext val is_mash_enabled : unit -> bool val the_mash_algorithm : unit -> mash_algorithm val str_of_mash_algorithm : mash_algorithm -> string val mesh_facts : ('a list -> 'a list) -> ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list val nickname_of_thm : thm -> string val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list val crude_thm_ord : Proof.context -> thm ord val thm_less : thm * thm -> bool val goal_of_thm : theory -> thm -> thm val run_prover_for_mash : Proof.context -> params -> string -> string -> fact list -> thm -> prover_result val features_of : Proof.context -> string -> stature -> term list -> string list val trim_dependencies : string list -> string list option val isar_dependencies_of : string Symtab.table * string Symtab.table -> thm -> string list option - val prover_dependencies_of : Proof.context -> params -> string -> int -> raw_fact list -> + val prover_dependencies_of : Proof.context -> params -> string -> int -> lazy_fact list -> string Symtab.table * string 
Symtab.table -> thm -> bool * string list val attach_parents_to_facts : ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list val num_extra_feature_facts : int val extra_feature_factor : real val weight_facts_smoothly : 'a list -> ('a * real) list val weight_facts_steeply : 'a list -> ('a * real) list val find_mash_suggestions : Proof.context -> int -> string list -> ('a * thm) list -> ('a * thm) list -> ('a * thm) list -> ('a * thm) list * ('a * thm) list val mash_suggested_facts : Proof.context -> string -> params -> int -> term list -> term -> - raw_fact list -> fact list * fact list + lazy_fact list -> fact list * fact list val mash_unlearn : Proof.context -> unit val mash_learn_proof : Proof.context -> params -> term -> thm list -> unit val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time -> - raw_fact list -> string + lazy_fact list -> string val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit val mash_can_suggest_facts : Proof.context -> bool val mash_can_suggest_facts_fast : Proof.context -> bool val generous_max_suggestions : int -> int val mepo_weight : real val mash_weight : real val relevant_facts : Proof.context -> params -> string -> int -> fact_override -> term list -> - term -> raw_fact list -> (string * fact list) list + term -> lazy_fact list -> (string * fact list) list end; structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH = struct open ATP_Util open ATP_Problem_Generate open Sledgehammer_Util open Sledgehammer_Fact open Sledgehammer_Prover open Sledgehammer_Prover_Minimize open Sledgehammer_MePo val anonymous_proof_prefix = "." 
val trace = Attrib.setup_config_bool \<^binding>\sledgehammer_mash_trace\ (K false) val duplicates = Attrib.setup_config_bool \<^binding>\sledgehammer_fact_duplicates\ (K false) fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else () fun gen_eq_thm ctxt = if Config.get ctxt duplicates then Thm.eq_thm_strict else Thm.eq_thm_prop val MePoN = "MePo" val MaShN = "MaSh" val MeShN = "MeSh" val mepoN = "mepo" val mashN = "mash" val meshN = "mesh" val fact_filters = [meshN, mepoN, mashN] val unlearnN = "unlearn" val learn_isarN = "learn_isar" val learn_proverN = "learn_prover" val relearn_isarN = "relearn_isar" val relearn_proverN = "relearn_prover" fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) type xtab = int * int Symtab.table val empty_xtab = (0, Symtab.empty) fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab) fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key)) fun state_file () = Path.expand (Path.explode "$ISABELLE_HOME_USER/mash_state") val remove_state_file = try File.rm o state_file datatype mash_algorithm = MaSh_NB | MaSh_kNN | MaSh_NB_kNN | MaSh_NB_Ext | MaSh_kNN_Ext fun mash_algorithm () = (case Options.default_string \<^system_option>\MaSh\ of "yes" => SOME MaSh_NB_kNN | "sml" => SOME MaSh_NB_kNN | "nb" => SOME MaSh_NB | "knn" => SOME MaSh_kNN | "nb_knn" => SOME MaSh_NB_kNN | "nb_ext" => SOME MaSh_NB_Ext | "knn_ext" => SOME MaSh_kNN_Ext | "none" => NONE | "" => NONE | algorithm => (warning ("Unknown MaSh algorithm: " ^ quote algorithm); NONE)) val is_mash_enabled = is_some o mash_algorithm val the_mash_algorithm = the_default MaSh_NB_kNN o mash_algorithm fun str_of_mash_algorithm MaSh_NB = "nb" | str_of_mash_algorithm MaSh_kNN = "knn" | str_of_mash_algorithm MaSh_NB_kNN = "nb_knn" | str_of_mash_algorithm MaSh_NB_Ext = "nb_ext" | str_of_mash_algorithm MaSh_kNN_Ext = "knn_ext" fun scaled_avg [] = 0 | scaled_avg xs = Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div 
length xs fun avg [] = 0.0 | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs) fun normalize_scores _ [] = [] | normalize_scores max_facts xs = map (apsnd (curry (op *) (1.0 / avg (map snd (take max_facts xs))))) xs fun mesh_facts maybe_distinct _ max_facts [(_, (sels, unks))] = map fst (take max_facts sels) @ take (max_facts - length sels) unks |> maybe_distinct | mesh_facts _ fact_eq max_facts mess = let val mess = mess |> map (apsnd (apfst (normalize_scores max_facts))) fun score_in fact (global_weight, (sels, unks)) = let val score_at = try (nth sels) #> Option.map (fn (_, score) => global_weight * score) in (case find_index (curry fact_eq fact o fst) sels of ~1 => if member fact_eq unks fact then NONE else SOME 0.0 | rank => score_at rank) end fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg in fold (union fact_eq o map fst o take max_facts o fst o snd) mess [] |> map (`weight_of) |> sort (int_ord o apply2 fst o swap) |> map snd |> take max_facts end fun smooth_weight_of_fact rank = Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0 (* FUDGE *) fun steep_weight_of_fact rank = Math.pow (0.62, log2 (Real.fromInt (rank + 1))) (* FUDGE *) fun weight_facts_smoothly facts = facts ~~ map smooth_weight_of_fact (0 upto length facts - 1) fun weight_facts_steeply facts = facts ~~ map steep_weight_of_fact (0 upto length facts - 1) fun sort_array_suffix cmp needed a = let exception BOTTOM of int val al = Array.length a fun maxson l i = let val i31 = i + i + i + 1 in if i31 + 2 < l then let val x = Unsynchronized.ref i31 in if is_less (cmp (Array.sub (a, i31), Array.sub (a, i31 + 1))) then x := i31 + 1 else (); if is_less (cmp (Array.sub (a, !x), Array.sub (a, i31 + 2))) then x := i31 + 2 else (); !x end else if i31 + 1 < l andalso is_less (cmp (Array.sub (a, i31), Array.sub (a, i31 + 1))) then i31 + 1 else if i31 < l then i31 else raise BOTTOM i end fun trickledown l i e = let val j = maxson l i in if is_greater (cmp (Array.sub (a, j), 
e)) then (Array.update (a, i, Array.sub (a, j)); trickledown l j e) else Array.update (a, i, e) end
  (* Like trickledown, but stores e at the leaf position carried by BOTTOM when the bottom of the
     heap is reached. *)
  fun trickle l i e = trickledown l i e handle BOTTOM i => Array.update (a, i, e)
  (* Moves the greatest child up into position i, level by level, until maxson raises BOTTOM. *)
  fun bubbledown l i = let val j = maxson l i in Array.update (a, i, Array.sub (a, j)); bubbledown l j end
  (* Runs bubbledown and returns the index carried by BOTTOM, i.e. where the hole ended up. *)
  fun bubble l i = bubbledown l i handle BOTTOM i => i
  (* Sifts e up from position i: fathers ((i - 1) div 3) that are less than e are moved down. *)
  fun trickleup i e = let val father = (i - 1) div 3 in if is_less (cmp (Array.sub (a, father), e)) then (Array.update (a, i, Array.sub (a, father)); if father > 0 then trickleup father e else Array.update (a, 0, e)) else Array.update (a, i, e) end
  (* Heap construction: sifts down every position from i to 0. *)
  fun for i = if i < 0 then () else (trickle al i (Array.sub (a, i)); for (i - 1))
  (* Extraction: for i counting down, moves the root to position i and re-inserts the displaced
     element; stops as soon as i < max 2 (al - needed). *)
  fun for2 i = if i < Integer.max 2 (al - needed) then () else let val e = Array.sub (a, i) in Array.update (a, i, Array.sub (a, 0)); trickleup (bubble i 0) e; for2 (i - 1) end
  (* After both loops, elements 0 and 1 are swapped whenever the array has more than one element. *)
  in for (((al + 1) div 3) - 1); for2 (al - 1); if al > 1 then let val e = Array.sub (a, 1) in Array.update (a, 1, Array.sub (a, 0)); Array.update (a, 0, e) end else () end
(* Copies "xs" into an array, runs sort_array_suffix on it, and rebuilds a list by consing the
   array elements from left to right, so the resulting list is in reverse array order. *)
fun rev_sort_list_prefix cmp needed xs = let val ary = Array.fromList xs in sort_array_suffix cmp needed ary; Array.foldl (op ::) [] ary end
(*** Convenience functions for synchronized access ***)
(* Reads the content of "var" through Synchronized.timed_access; the content is stored back
   unchanged. *)
fun synchronized_timed_value var time_limit = Synchronized.timed_access var time_limit (fn value => SOME (value, value))
(* Applies "f" to the content of "var" through Synchronized.timed_access; "f" yields a pair
   (result, new content). *)
fun synchronized_timed_change_result var time_limit f = Synchronized.timed_access var time_limit (SOME o f)
(* Replaces the content of "var" by "f" applied to it, with unit as the result. *)
fun synchronized_timed_change var time_limit f = synchronized_timed_change_result var time_limit (fn x => ((), f x))
(* Time limit function passed to the helpers above: ignores its argument and always yields
   SOME (seconds 0.1). *)
fun mash_time_limit _ = SOME (seconds 0.1)
(*** Isabelle-agnostic machine learning ***)
structure MaSh = struct
(* Adds "big_number" to the score stored at each of the given indices of the array "recommends";
   the list of indices is the (eta-reduced) third argument. *)
fun select_fact_idxs (big_number : real) recommends = List.app (fn at => let val (j, ov) = Array.sub (recommends, at) in Array.update (recommends, at, (j, big_number + ov)) end)
(* Creates an array via "Array.array init" (so "init" is a (size, default) pair) and copies "vec"
   into it starting at index 0. *)
fun wider_array_of_vector init vec = let val ary = Array.array init in Array.copyVec {src = vec, dst = ary, di = 0}; ary end
val nb_def_prior_weight = 
1000 (* FUDGE *)
(* Extends the frequency tables (tfreq0, sfreq0, dffreq0) with the facts num_facts0 .. num_facts - 1.
   - tfreq: one integer counter per fact;
   - sfreq: one Inttab per fact, mapping feature indices to counters;
   - dffreq: one integer counter per feature.
   For each new fact i with features featss[i] and dependencies depss[i]:
   the fact itself receives weight nb_def_prior_weight and each dependency weight 1, in both cases
   added to the fact's/dependency's tfreq entry and to its sfreq counters for the features of
   fact i; finally dffreq is incremented once for every feature of fact i.
   Returns the three tables as vectors. *)
fun learn_facts (tfreq0, sfreq0, dffreq0) num_facts0 num_facts num_feats depss featss = let val tfreq = wider_array_of_vector (num_facts, 0) tfreq0 val sfreq = wider_array_of_vector (num_facts, Inttab.empty) sfreq0 val dffreq = wider_array_of_vector (num_feats, 0) dffreq0 fun learn_one th feats deps = let fun add_th weight t = let val im = Array.sub (sfreq, t) fun fold_fn s = Inttab.map_default (s, 0) (Integer.add weight) in map_array_at tfreq (Integer.add weight) t; Array.update (sfreq, t, fold fold_fn feats im) end val add_sym = map_array_at dffreq (Integer.add 1) in add_th nb_def_prior_weight th; List.app (add_th 1) deps; List.app add_sym feats end fun for i = if i = num_facts then () else (learn_one i (Vector.sub (featss, i)) (Vector.sub (depss, i)); for (i + 1)) in for num_facts0; (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq) end
(* Scores every fact index 0 .. num_facts - 1 with "log_posterior" and returns (index, score)
   pairs read from the last max_suggs slots of the partially sorted score array.
   "goal_feats" is a list of (feature index, weight) pairs. The idf of a feature is
   ln num_facts - ln dffreq[feature]. In log_posterior, goal features present in the fact's sfreq
   table add weight * idf * ln (pos_weight * count / tfreq) and are removed from the table; absent
   ones add weight * idf * def_val; the features remaining in the table contribute
   tau * idf * ln (1.0 - (count - 1) / tfreq). *)
fun naive_bayes (tfreq, sfreq, dffreq) num_facts max_suggs fact_idxs goal_feats = let val tau = 0.2 (* FUDGE *) val pos_weight = 5.0 (* FUDGE *) val def_val = ~18.0 (* FUDGE *) val init_val = 30.0 (* FUDGE *) val ln_afreq = Math.ln (Real.fromInt num_facts) val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq fun tfidf feat = Vector.sub (idf, feat) fun log_posterior i = let val tfreq = Real.fromInt (Vector.sub (tfreq, i)) fun add_feat (f, fw0) (res, sfh) = (case Inttab.lookup sfh f of SOME sf => (res + fw0 * tfidf f * Math.ln (pos_weight * Real.fromInt sf / tfreq), Inttab.delete f sfh) | NONE => (res + fw0 * tfidf f * def_val, sfh)) val (res, sfh) = fold add_feat goal_feats (init_val * Math.ln tfreq, Vector.sub (sfreq, i)) fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 - Real.fromInt (sf - 1) / tfreq) val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 in res + tau * sum_of_weights end val posterior = Array.tabulate (num_facts, (fn j => (j, log_posterior j))) fun ret at acc = if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc) in select_fact_idxs 
100000.0 posterior fact_idxs; sort_array_suffix (Real.compare o apply2 snd) max_suggs posterior; ret (Integer.max 0 (num_facts - max_suggs)) [] end
(* The first loop of k_nearest_neighbors below runs until its counter equals initial_k + 1. *)
val initial_k = 0
(* Nearest-neighbour style ranking; returns (fact index, score) pairs read from the last max_suggs
   slots of the partially sorted "recommends" array.
   1. feat_facts maps every feature index to the facts having that feature.
   2. overlaps_sqr[j] accumulates, over the goal features s shared with fact j, the value
      (goal weight * tfidf s) raised to the 6th power; the array is then sorted by that value.
   3. do_k visits the fact with the (k + 1)-th position from the end of the sorted array and
      recommends it with its overlap o2, and each of its dependencies with
      deps_factor * o2 / (number of dependencies). It raises EXIT once k >= num_facts.
   4. inc_recommend: the first recommendation of an index (score <= 0.0) counts towards
      no_recommends and adds the current "age"; later recommendations add the passed value.
   5. while1 runs do_k until k = initial_k + 1; while2 continues, lowering "age" by 10000.0 per
      step, until no_recommends >= max_suggs. Both stop silently on EXIT.
   6. The indices in fact_idxs get 1000000000.0 added before the final partial sort. *)
fun k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs goal_feats = let exception EXIT of unit val ln_afreq = Math.ln (Real.fromInt num_facts) fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Vector.sub (dffreq, feat))) val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0) val feat_facts = Array.array (num_feats, []) val _ = Vector.foldl (fn (feats, fact) => (List.app (map_array_at feat_facts (cons fact)) feats; fact + 1)) 0 featss fun do_feat (s, sw0) = let val sw = sw0 * tfidf s val w6 = Math.pow (sw, 6.0 (* FUDGE *)) fun inc_overlap j = let val (_, ov) = Array.sub (overlaps_sqr, j) in Array.update (overlaps_sqr, j, (j, w6 + ov)) end in List.app inc_overlap (Array.sub (feat_facts, s)) end val _ = List.app do_feat goal_feats val _ = sort_array_suffix (Real.compare o apply2 snd) num_facts overlaps_sqr val no_recommends = Unsynchronized.ref 0 val recommends = Array.tabulate (num_facts, rpair 0.0) val age = Unsynchronized.ref 500000000.0 fun inc_recommend v j = let val (_, ov) = Array.sub (recommends, j) in if ov <= 0.0 then (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov))) else Array.update (recommends, j, (j, v + ov)) end val k = Unsynchronized.ref 0 fun do_k k = if k >= num_facts then raise EXIT () else let val deps_factor = 2.7 (* FUDGE *) val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1) val _ = inc_recommend o2 j val ds = Vector.sub (depss, j) val l = Real.fromInt (length ds) in List.app (inc_recommend (deps_factor * o2 / l)) ds end fun while1 () = if !k = initial_k + 1 then () else (do_k (!k); k := !k + 1; while1 ()) handle EXIT () => () fun while2 () = if !no_recommends >= max_suggs then () else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ()) handle EXIT () => () fun ret acc at = if at = num_facts then acc else 
ret (Array.sub (recommends, at) :: acc) (at + 1) in while1 (); while2 (); select_fact_idxs 1000000000.0 recommends fact_idxs; sort_array_suffix (Real.compare o apply2 snd) max_suggs recommends; ret [] (Integer.max 0 (num_facts - max_suggs)) end
(* experimental *)
(* Queries an external program. Writes four files whose names end in a fresh serial number:
   adv_syms (fact name, ":" and its quoted features separated by ", "), adv_deps (fact name, ":"
   and its dependencies separated by " "), adv_seq (fact names, one per line) and adv_conj (the
   quoted goal feature names). Then runs "~/misc/" ^ tool through Isabelle_System.bash_output and
   returns the non-empty space-separated words of the first component of its result.
   NOTE(review): the four output streams are closed only on the normal path, and the command line
   is assembled by plain string concatenation of "tool" -- acceptable for experimental code, but
   worth confirming before wider use. *)
fun external_tool tool max_suggs learns goal_feats = let val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *) val ocs = TextIO.openOut ("adv_syms" ^ ser) val ocd = TextIO.openOut ("adv_deps" ^ ser) val ocq = TextIO.openOut ("adv_seq" ^ ser) val occ = TextIO.openOut ("adv_conj" ^ ser) fun os oc s = TextIO.output (oc, s) fun ol _ _ _ [] = () | ol _ f _ [e] = f e | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t) fun do_learn (name, feats, deps) = (os ocs name; os ocs ":"; ol ocs (os ocs o quote) ", " feats; os ocs "\n"; os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n"; os ocq name; os ocq "\n") fun forkexec no = let val cmd = "~/misc/" ^ tool ^ " adv_syms" ^ ser ^ " adv_deps" ^ ser ^ " " ^ string_of_int no ^ " adv_seq" ^ ser ^ " < adv_conj" ^ ser in fst (Isabelle_System.bash_output cmd) |> space_explode " " |> filter_out (curry (op =) "") end in (List.app do_learn learns; ol occ (os occ o quote) ", " (map fst goal_feats); TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ; forkexec max_suggs) end
(* external_tool with the tool string "newknn/knn" followed by a space and initial_k. *)
fun k_nearest_neighbors_ext max_suggs = external_tool ("newknn/knn" ^ " " ^ string_of_int initial_k) max_suggs
(* external_tool with the tool string "predict/nbayes". *)
fun naive_bayes_ext max_suggs = external_tool "predict/nbayes" max_suggs
(* Traces the goal feature names and dispatches to one of the two external tools.
   NOTE(review): only MaSh_NB_Ext and MaSh_kNN_Ext are matched here; the other algorithm
   constructors handled by query_internal would presumably raise Match -- callers must filter. *)
fun query_external ctxt algorithm max_suggs learns goal_feats = (trace_msg ctxt (fn () => "MaSh query external " ^ commas (map fst goal_feats)); (case algorithm of MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats | MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats))
(* Runs the built-in learners (naive_bayes and/or k_nearest_neighbors, depending on "algorithm")
   on the integer-encoded goal features and maps the resulting fact indices back to names via
   the vector "fact_names". *)
fun query_internal ctxt algorithm num_facts num_feats (fact_names, featss, depss) (freqs as (_, _, dffreq)) fact_idxs max_suggs goal_feats int_goal_feats = let fun nb 
() = naive_bayes freqs num_facts max_suggs fact_idxs int_goal_feats |> map fst fun knn () = k_nearest_neighbors dffreq num_facts num_feats depss featss max_suggs fact_idxs int_goal_feats |> map fst in (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^ elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}"); (case algorithm of MaSh_NB => nb () | MaSh_kNN => knn () | MaSh_NB_kNN => mesh_facts I (op =) max_suggs [(0.5 (* FUDGE *), (weight_facts_steeply (nb ()), [])), (0.5 (* FUDGE *), (weight_facts_steeply (knn ()), []))]) |> map (curry Vector.sub fact_names)) end end;
(*** Persistent, stringly-typed state ***)
(* Encodes one character: alphanumerics and the characters _ . ( ) , ' are kept as they are; any
   other character becomes "%" followed by its code as produced by stringN_of_int 3. *)
fun meta_char c = if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse c = #")" orelse c = #"," orelse c = #"'" then String.str c else (* fixed width, in case more digits follow *) "%" ^ stringN_of_int 3 (Char.ord c)
(* Decodes a character list produced by meta_char; "accum" holds the decoded characters in reverse.
   A "%" that is not followed by three characters parseable by Int.fromString makes the whole
   result the empty string.
   NOTE(review): Char.chr presumably raises Chr for codes above Char.maxOrd (e.g. "%999"), which
   would bypass the empty-string error convention -- confirm whether callers rely on that. *)
fun unmeta_chars accum [] = String.implode (rev accum) | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) = (case Int.fromString (String.implode [d1, d2, d3]) of SOME n => unmeta_chars (Char.chr n :: accum) cs | NONE => "" (* error *)) | unmeta_chars _ (#"%" :: _) = "" (* error *) | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
(* Encodes a string character by character with meta_char. *)
val encode_str = String.translate meta_char
(* Encodes every string of a list and joins the results with single spaces. *)
val encode_strs = map encode_str #> space_implode " "
(* Inverse of encode_str; strings without "%" are returned unchanged without decoding. *)
fun decode_str s = if String.isSubstring "%" s then unmeta_chars [] (String.explode s) else s;
(* Splits at spaces and decodes the pieces; decoding is skipped when the input contains no "%". *)
fun decode_strs s = space_explode " " s |> String.isSubstring "%" s ? 
map decode_str;
(* How the dependencies of a learned fact were obtained. *)
datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop
(* One-letter code of a proof kind: "i", "a" or "x". *)
fun str_of_proof_kind Isar_Proof = "i" | str_of_proof_kind Automatic_Proof = "a" | str_of_proof_kind Isar_Proof_wegen_Prover_Flop = "x"
(* Inverse of str_of_proof_kind; every string other than "a" and "x" is read as Isar_Proof. *)
fun proof_kind_of_str "a" = Automatic_Proof | proof_kind_of_str "x" = Isar_Proof_wegen_Prover_Flop | proof_kind_of_str _ (* "i" *) = Isar_Proof
(* Makes sure a node for "parent" exists (created with (Isar_Proof, [], []) if missing) and adds
   the edge from "parent" to "name". *)
fun add_edge_to name parent = Graph.default_node (parent, (Isar_Proof, [], [])) #> Graph.add_edge (parent, name)
(* Adds the fact "name" to the accumulator (graph, (fact table, feature table), learns):
   the name is entered into the fact table, the node (kind, feats, deps) is created or, if it
   already exists (Graph.DUP), overwritten, edges from all "parents" are added, the features are
   entered into the feature table, and (name, feats, deps) is prepended to "learns".
   On Symtab.DUP (presumably raised by add_to_xtab for an already known name) the accumulator is
   returned unchanged. *)
fun add_node kind name parents feats deps (accum as (access_G, (fact_xtab, feat_xtab), learns)) = let val fact_xtab' = add_to_xtab name fact_xtab in ((Graph.new_node (name, (kind, feats, deps)) access_G handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G) |> fold (add_edge_to name) parents, (fact_xtab', fold maybe_add_to_xtab feats feat_xtab), (name, feats, deps) :: learns) end handle Symtab.DUP _ => accum
(* robustness (in case the state file violates the invariant) *)
(* Evaluates "f ()" and falls back to "def" when a graph exception (CYCLES, DUP, UNDEF) or any
   other non-interrupt exception occurs, after tracing a message that mentions "when".
   Interrupts are re-raised. *)
fun try_graph ctxt when def f = f () handle Graph.CYCLES (cycle :: _) => (trace_msg ctxt (fn () => "Cycle involving " ^ commas cycle ^ " when " ^ when); def) | Graph.DUP name => (trace_msg ctxt (fn () => "Duplicate fact " ^ quote name ^ " when " ^ when); def) | Graph.UNDEF name => (trace_msg ctxt (fn () => "Unknown fact " ^ quote name ^ " when " ^ when); def) | exn => if Exn.is_interrupt exn then Exn.reraise exn else (trace_msg ctxt (fn () => "Internal error when " ^ when ^ ":\n" ^ Runtime.exn_message exn); def)
(* Human-readable summary of a graph: numbers of keys, of edges (summed over Graph.dest) and of
   maximal nodes. *)
fun graph_info G = string_of_int (length (Graph.keys G)) ^ " node(s), " ^ string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^ " edge(s), " ^ string_of_int (length (Graph.maximals G)) ^ " maximal"
(* Fact names, per-fact feature indices and per-fact dependency indices, as parallel vectors. *)
type ffds = string vector * int list vector * int list vector
(* Frequency tables in the form produced by MaSh.learn_facts: (tfreq, sfreq, dffreq). *)
type freqs = int vector * int Inttab.table vector * int vector
(* The complete learner state. A "dirty_facts" value of SOME [] means there is nothing to save
   (see save_state). *)
type mash_state = {access_G : (proof_kind * string list * string list) Graph.T, xtabs : xtab * xtab, ffds : ffds, freqs : freqs, dirty_facts : string list 
option}
(* Empty fact and feature tables. *)
val empty_xtabs = (empty_xtab, empty_xtab)
val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList []) : ffds
val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList []) : freqs
(* State with an empty graph, empty tables and nothing to save. *)
val empty_state = {access_G = Graph.empty, xtabs = empty_xtabs, ffds = empty_ffds, freqs = empty_freqs, dirty_facts = SOME []} : mash_state
(* Appends the newly learned facts "learns" (name, feature names, dependency names) to the vectors
   (fact_names0, featss0, depss0) and updates the frequency tables for the facts with indices
   num_facts0 .. num_facts - 1 via MaSh.learn_facts. Feature and dependency names are translated
   to indices through the symbol tables; names that are not found are silently dropped. *)
fun recompute_ffds_freqs_from_learns (learns : (string * string list * string list) list) ((num_facts, fact_tab), (num_feats, feat_tab)) num_facts0 (fact_names0, featss0, depss0) freqs0 = let val fact_names = Vector.concat [fact_names0, Vector.fromList (map #1 learns)] val featss = Vector.concat [featss0, Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)] val depss = Vector.concat [depss0, Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)] in ((fact_names, featss, depss), MaSh.learn_facts freqs0 num_facts0 num_facts num_feats depss featss) end
(* Reorders "learns" so that every entry sits at the index that fact_tab assigns to its fact name.
   Indices that receive no entry keep the placeholder ("", [], []); a fact missing from fact_tab
   makes "the" fail. *)
fun reorder_learns (num_facts, fact_tab) learns = let val ary = Array.array (num_facts, ("", [], [])) in List.app (fn learn as (fact, _, _) => Array.update (ary, the (Symtab.lookup fact_tab fact), learn)) learns; Array.foldr (op ::) [] ary end
(* Rebuilds ffds and freqs from scratch out of all nodes of the graph, ordered by fact index. *)
fun recompute_ffds_freqs_from_access_G access_G (xtabs as (fact_xtab, _)) = let val learns = Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G |> reorder_learns fact_xtab in recompute_ffds_freqs_from_learns learns xtabs 0 empty_ffds empty_freqs end
local
(* Banner expected as the first line of the state file; compared with string_ord in load_state. *)
val version = "*** MaSh version 20190121 ***"
(* Raised by load_state when the banner of the state file is greater than "version". *)
exception FILE_VERSION_TOO_NEW of unit
(* Parses one state file line of the shape "kind name: parents; feats; deps" (the shape written by
   str_of_entry); yields NONE when the line does not split into exactly these fields. *)
fun extract_node line = (case space_explode ":" line of [head, tail] => (case (space_explode " " head, map (unprefix " ") (space_explode ";" tail)) of ([kind, name], [parents, feats, deps]) => SOME (proof_kind_of_str kind, decode_str name, decode_strs parents, decode_strs feats, decode_strs deps) | _ => NONE) | _ => NONE)
(* True iff the state file exists and its modification time is later than "memory_time". *)
fun would_load_state (memory_time, _) = let val path = state_file () in (case try 
OS.FileSys.modTime (File.platform_path path) of NONE => false | SOME disk_time => memory_time < disk_time) end;
(* Refreshes the in-memory (time, state) pair from the state file.
   - No readable modification time, or file not newer than memory_time: the pair is returned as is.
   - Otherwise the file is read and the result is stamped with the file's modification time:
     * banner equal to "version": every line is parsed and added with add_node, under try_graph
       with an empty graph/tables as fallback;
     * banner less than "version": the state file is removed and an empty graph is used;
     * banner greater than "version": FILE_VERSION_TOO_NEW is raised;
     * unreadable or empty file: empty_state.
   The freshly loaded state has dirty_facts = SOME []. *)
fun load_state ctxt (time_state as (memory_time, _)) = let val path = state_file () in (case try OS.FileSys.modTime (File.platform_path path) of NONE => time_state | SOME disk_time => if memory_time >= disk_time then time_state else (disk_time, (case try File.read_lines path of SOME (version' :: node_lines) => let fun extract_line_and_add_node line = (case extract_node line of NONE => I (* should not happen *) | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps) val empty_G_etc = (Graph.empty, empty_xtabs, []) val (access_G, xtabs, rev_learns) = (case string_ord (version', version) of EQUAL => try_graph ctxt "loading state" empty_G_etc (fn () => fold extract_line_and_add_node node_lines empty_G_etc) | LESS => (remove_state_file (); empty_G_etc) (* cannot parse old file *) | GREATER => raise FILE_VERSION_TOO_NEW ()) val (ffds, freqs) = recompute_ffds_freqs_from_learns (rev rev_learns) xtabs 0 empty_ffds empty_freqs in trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")"); {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []} end | _ => empty_state))) end
(* Serializes one graph entry as a state file line: "kind name: parents; feats; deps" followed by
   a newline, with names encoded by encode_str/encode_strs. *)
fun str_of_entry (kind, name, parents, feats, deps) = str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^ encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
(* Writes the state to the state file. Nothing happens when dirty_facts is SOME [].
   If the file's modification time is not later than memory_time, only the entries of the dirty
   facts are appended; otherwise (including a missing file or dirty_facts = NONE) the file is
   rewritten from the version banner and all graph entries. *)
fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state | save_state ctxt (memory_time, {access_G, xtabs, ffds, freqs, dirty_facts}) = let fun append_entry (name, ((kind, feats, deps), (parents, _))) = cons (kind, name, Graph.Keys.dest parents, feats, deps) val path = state_file () val dirty_facts' = (case try OS.FileSys.modTime (File.platform_path path) of NONE => NONE | SOME disk_time => if disk_time <= memory_time then dirty_facts else NONE) val (banner, entries) = (case dirty_facts' of SOME names => (NONE, 
fold (append_entry o Graph.get_entry access_G) names []) | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G [])) in
  (* The banner, if any, overwrites the file; the entries are appended in chunks of 500.
     IO.Io errors are deliberately ignored. The returned pair carries the current time and
     dirty_facts = SOME []. *)
  (case banner of SOME s => File.write path s | NONE => (); entries |> chunk_list 500 |> List.app (File.append path o implode o map str_of_entry)) handle IO.Io _ => (); trace_msg ctxt (fn () => "Saved fact graph (" ^ graph_info access_G ^ (case dirty_facts of SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)" | _ => "") ^ ")"); (Time.now (), {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}) end
(* The shared (time stamp, state) pair, initially (Time.zeroTime, empty_state). *)
val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
in
(* Loads the state, applies "f" to it and saves the result, all inside one timed access to
   global_state. The outcome of the timed access is ignored, and FILE_VERSION_TOO_NEW is
   swallowed. *)
fun map_state ctxt f = (trace_msg ctxt (fn () => "Changing MaSh state"); synchronized_timed_change global_state mash_time_limit (load_state ctxt ##> f #> save_state ctxt)) |> ignore handle FILE_VERSION_TOO_NEW () => ()
(* Returns the current content of global_state without touching the state file; NONE when the
   timed access yields NONE or when would_load_state holds for the content. *)
fun peek_state ctxt = (trace_msg ctxt (fn () => "Peeking at MaSh state"); (case synchronized_timed_value global_state mash_time_limit of NONE => NONE | SOME state => if would_load_state state then NONE else SOME state))
(* Reloads the state from disk if possible (failures of load_state are ignored via "try") and
   returns the second component of the stored pair. *)
fun get_state ctxt = (trace_msg ctxt (fn () => "Retrieving MaSh state"); synchronized_timed_change_result global_state mash_time_limit (perhaps (try (load_state ctxt)) #> `snd))
(* Removes the state file and resets global_state to (Time.zeroTime, empty_state). *)
fun clear_state ctxt = (trace_msg ctxt (fn () => "Clearing MaSh state"); Synchronized.change global_state (fn _ => (remove_state_file (); (Time.zeroTime, empty_state))))
end
(*** Isabelle helpers ***)
(* Prints a term into a compact string, stopping once the budget "size" is used up; every printed
   abstraction, bound variable, constant, free variable and schematic variable costs 1.
   Applications are printed as "(" t " " u ")", constants by their base name. *)
fun crude_printed_term size t = let fun term _ (res, 0) = (res, 0) | term (t $ u) (res, size) = let val (res, size) = term t (res ^ "(", size) val (res, size) = term u (res ^ " ", size) in (res ^ ")", size) end | term (Abs (s, _, t)) (res, size) = term t (res ^ "%" ^ s ^ ".", size - 1) | term (Bound n) (res, size) = (res ^ "#" ^ string_of_int n, size - 1) | term (Const (s, _)) (res, size) = (res ^ Long_Name.base_name s, size - 1) | term (Free (s, _)) (res, size) = (res ^ s, size 
- 1) | term (Var ((s, _), _)) (res, size) = (res ^ s, size - 1) in fst (term t ("", size)) end
(* Name under which a theorem is known to the learner:
   - theorems with a name hint use the hint, unless Long_Name.dest_local recognizes it as local,
     in which case the name is built from the theory name, the local suffix and a crude print of
     the proposition (budget 25);
   - theorems without a name hint use a crude print of the proposition (budget 50). *)
fun nickname_of_thm th = if Thm.has_name_hint th then let val hint = Thm.get_name_hint th in (* There must be a better way to detect local facts. *) (case Long_Name.dest_local hint of SOME suf => Long_Name.implode [Thm.theory_name th, suf, crude_printed_term 25 (Thm.prop_of th)] | NONE => hint) end else crude_printed_term 50 (Thm.prop_of th)
(* Builds a table from nicknames to facts and returns a function that maps a list of suggested
   nicknames to the corresponding facts; unknown nicknames are traced and dropped.
   NOTE(review): Symtab.default presumably keeps the first fact when nicknames clash -- confirm. *)
fun find_suggested_facts ctxt facts = let fun add (fact as (_, th)) = Symtab.default (nickname_of_thm th, fact) val tab = fold add facts Symtab.empty fun lookup nick = Symtab.lookup tab nick |> tap (fn NONE => trace_msg ctxt (fn () => "Cannot find " ^ quote nick) | _ => ()) in map_filter lookup end
(* One-letter prefixes that tag the different kinds of features. *)
fun free_feature_of s = "f" ^ s
fun thy_feature_of s = "y" ^ s
fun type_feature_of s = "t" ^ s
fun class_feature_of s = "s" ^ s
val local_feature = "local"
(* Ordering on pairs whose components are theorems (Thm.theory_id is applied to both components).
   Theories are compared first: equal ids are EQUAL, proper sub-theories come first; unrelated
   theories are ordered by their number of ancestors and then by name. Theorems of the same
   theory are ordered by nickname, see the comments inside. *)
fun crude_thm_ord ctxt = let val ancestor_lengths = fold (fn thy => Symtab.update (Context.theory_name thy, length (Context.ancestors_of thy))) (Theory.nodes_of (Proof_Context.theory_of ctxt)) Symtab.empty val ancestor_length = Symtab.lookup ancestor_lengths o Context.theory_id_name fun crude_theory_ord p = if Context.eq_thy_id p then EQUAL else if Context.proper_subthy_id p then LESS else if Context.proper_subthy_id (swap p) then GREATER else (case apply2 ancestor_length p of (SOME m, SOME n) => (case int_ord (m, n) of EQUAL => string_ord (apply2 Context.theory_id_name p) | ord => ord) | _ => string_ord (apply2 Context.theory_id_name p)) in fn p => (case crude_theory_ord (apply2 Thm.theory_id p) of EQUAL => (* The hack below is necessary because of odd dependencies that are not reflected in the theory comparison. 
*) let val q = apply2 nickname_of_thm p in (* Hack to put "xxx_def" before "xxxI" and "xxxE" *) (case bool_ord (apply2 (String.isSuffix "_def") (swap q)) of EQUAL => string_ord q | ord => ord) end | ord => ord) end;
(* thm_less_eq (th1, th2) holds iff Context.subthy_id holds for the theory ids of th1 and th2. *)
val thm_less_eq = Context.subthy_id o apply2 Thm.theory_id
(* Strict version: thm_less_eq holds for the pair but not for the swapped pair. *)
fun thm_less p = thm_less_eq p andalso not (thm_less_eq (swap p))
val freezeT = Type.legacy_freeze_type
(* Turns every schematic variable of a term into a free variable of the same base name and
   applies freezeT to all types occurring in abstractions, variables and constants. *)
fun freeze (t $ u) = freeze t $ freeze u | freeze (Abs (s, T, t)) = Abs (s, freezeT T, freeze t) | freeze (Var ((s, _), T)) = Free (s, freezeT T) | freeze (Const (s, T)) = Const (s, freezeT T) | freeze (Free (s, T)) = Free (s, freezeT T) | freeze t = t
(* Turns the frozen proposition of a theorem into an initial goal state. *)
fun goal_of_thm thy = Thm.prop_of #> freeze #> Thm.global_cterm_of thy #> Goal.init
(* Runs the minimizing prover on a single-subgoal problem with the given facts as its only,
   unnamed fact group. *)
fun run_prover_for_mash ctxt params prover goal_name facts goal = let val problem = {comment = "Goal: " ^ goal_name, state = Proof.init ctxt, goal = goal, subgoal = 1, subgoal_count = 1, factss = [("", facts)], found_proof = I} in get_minimizing_prover ctxt MaSh (K ()) prover params problem end
(* Type constructors for which pattify_type produces no feature.
   NOTE(review): the antiquotations below and the sort antiquotations further down look as if
   their cartouche delimiters were lost when this text was extracted -- compare with the
   original file before compiling. *)
val bad_types = [\<^type_name>\prop\, \<^type_name>\bool\, \<^type_name>\fun\]
(* Comma-separated base names of the classes of a sort. *)
val crude_str_of_sort = space_implode "," o map Long_Name.base_name o subtract (op =) \<^sort>\type\
(* Crude string for a type: base name of the constructor followed by the strings of its
   arguments; type variables are represented by their sort. *)
fun crude_str_of_typ (Type (s, [])) = Long_Name.base_name s | crude_str_of_typ (Type (s, Ts)) = Long_Name.base_name s ^ implode (map crude_str_of_typ Ts) | crude_str_of_typ (TFree (_, S)) = crude_str_of_sort S | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
(* Empty list for the empty string, singleton list otherwise. *)
fun maybe_singleton_str "" = [] | maybe_singleton_str s = [s]
(* Upper bound on the number of alternative patterns kept per subterm/subtype in pattify_*. *)
val max_pat_breadth = 5 (* FUDGE *)
(* Collects the feature strings of the terms "ts" (see the local functions for details). *)
fun term_features_of ctxt thy_name term_max_depth type_max_depth ts = let val thy = Proof_Context.theory_of ctxt val fixes = map snd (Variable.dest_fixes ctxt) val classes = Sign.classes_of thy
  (* Adds class features for every class of a sort together with its superclasses. *)
  fun add_classes \<^sort>\type\ = I | add_classes S = fold (`(Sorts.super_classes classes) #> swap #> op :: #> subtract (op =) \<^sort>\type\ #> map class_feature_of #> union (op =)) S
  (* Patterns of a type up to the given depth; depth 0 yields no pattern. *)
  fun pattify_type 0 _ = [] | pattify_type _ (Type (s, [])) = if 
member (op =) bad_types s then [] else [s] | pattify_type depth (Type (s, U :: Ts)) = let val T = Type (s, Ts) val ps = take max_pat_breadth (pattify_type depth T) val qs = take max_pat_breadth ("" :: pattify_type (depth - 1) U) in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end | pattify_type _ (TFree (_, S)) = maybe_singleton_str (crude_str_of_sort S) | pattify_type _ (TVar (_, S)) = maybe_singleton_str (crude_str_of_sort S)
  (* Adds the type-feature-tagged patterns of T at one particular depth. *)
  fun add_type_pat depth T = union (op =) (map type_feature_of (pattify_type depth T))
  (* Adds the patterns of a type for all depths from "depth" down to 1. *)
  fun add_type_pats 0 _ = I | add_type_pats depth t = add_type_pat depth t #> add_type_pats (depth - 1) t
  (* Adds the patterns of T up to type_max_depth plus the class features of its sorts. *)
  fun add_type T = add_type_pats type_max_depth T #> fold_atyps_sorts (add_classes o snd) T
  (* add_type for T and, recursively, for all argument types. *)
  fun add_subtypes (T as Type (_, Ts)) = add_type T #> fold add_subtypes Ts | add_subtypes T = add_type T
  (* Patterns of a term up to the given depth. Constants contribute their name unless
     is_widely_irrelevant_const holds; free, schematic and bound variables contribute the crude
     string of their type (free variables that are fixed in the context additionally contribute
     a free-feature built from thy_name and the variable name); an application combines at most
     max_pat_breadth patterns of the function with those of the argument at depth - 1, written
     as "p(q)", or just "p" when the argument pattern is omitted. "Ts" are the types of the
     enclosing bound variables. *)
  fun pattify_term _ 0 _ = [] | pattify_term _ _ (Const (s, _)) = if is_widely_irrelevant_const s then [] else [s] | pattify_term _ _ (Free (s, T)) = maybe_singleton_str (crude_str_of_typ T) |> (if member (op =) fixes s then cons (free_feature_of (Long_Name.append thy_name s)) else I) | pattify_term _ _ (Var (_, T)) = maybe_singleton_str (crude_str_of_typ T) | pattify_term Ts _ (Bound j) = maybe_singleton_str (crude_str_of_typ (nth Ts j)) | pattify_term Ts depth (t $ u) = let val ps = take max_pat_breadth (pattify_term Ts depth t) val qs = take max_pat_breadth ("" :: pattify_term Ts (depth - 1) u) in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end | pattify_term _ _ _ = []
  (* Adds the patterns of a term at one particular depth. *)
  fun add_term_pat Ts = union (op =) oo pattify_term Ts
  (* Adds the patterns of a term for all depths from "depth" down to 1. *)
  fun add_term_pats _ 0 _ = I | add_term_pats Ts depth t = add_term_pat Ts depth t #> add_term_pats Ts (depth - 1) t
  (* Adds the patterns of a term up to term_max_depth. *)
  fun add_term Ts = add_term_pats Ts term_max_depth
  (* Traverses a term: the head of every application decides which term/type features are added,
     then all arguments are traversed recursively. *)
  fun add_subterms Ts t = (case strip_comb t of (Const (s, T), args) => (not (is_widely_irrelevant_const s) ? 
add_term Ts t) #> add_subtypes T #> fold (add_subterms Ts) args | (head, args) => (case head of Free (_, T) => add_term Ts t #> add_subtypes T | Var (_, T) => add_subtypes T | Abs (_, T, body) => add_subtypes T #> add_subterms (T :: Ts) body | _ => I) #> fold (add_subterms Ts) args) in fold (add_subterms []) ts [] end
(* Maximal pattern depths passed to term_features_of by features_of. *)
val term_max_depth = 2
val type_max_depth = 1
(* TODO: Generate type classes for types? *)
(* Features of the terms "ts": the theory feature of thy_name, the term features, and
   additionally local_feature when the scope is not Global. *)
fun features_of ctxt thy_name (scope, _) ts = thy_feature_of thy_name :: term_features_of ctxt thy_name term_max_depth type_max_depth ts |> scope <> Global ? cons local_feature
(* Too many dependencies is a sign that a decision procedure is at work. There is not much to learn from such proofs. *)
val max_dependencies = 20 (* FUDGE *)
(* Number of relevance-filter facts used by prover_dependencies_of when params give no max_facts. *)
val prover_default_max_facts = 25 (* FUDGE *)
(* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
val typedef_dep = nickname_of_thm @{thm CollectI}
(* Mysterious parts of the class machinery create lots of proofs that refer exclusively to "someI_ex" (and to some internal constructions). *)
val class_some_dep = nickname_of_thm @{thm someI_ex}
(* Nicknames of function-package facts; proofs depending on any of them are treated as having no
   dependencies by isar_dependencies_of. *)
val fundef_ths = @{thms fundef_ex1_existence fundef_ex1_uniqueness fundef_ex1_iff fundef_default_value} |> map nickname_of_thm
(* "Rep_xxx_inject", "Abs_xxx_inverse", etc., are derived using these facts. 
*) val typedef_ths = @{thms type_definition.Abs_inverse type_definition.Rep_inverse type_definition.Rep type_definition.Rep_inject type_definition.Abs_inject type_definition.Rep_cases type_definition.Abs_cases type_definition.Rep_induct type_definition.Abs_induct type_definition.Rep_range type_definition.Abs_image} |> map nickname_of_thm
(* True iff the only dependency contains ".rec", the theorem's nickname contains ".size", and the
   parts before these markers coincide. *)
fun is_size_def [dep] th = (case first_field ".rec" dep of SOME (pref, _) => (case first_field ".size" (nickname_of_thm th) of SOME (pref', _) => pref = pref' | NONE => false) | NONE => false) | is_size_def _ _ = false
(* NONE when there are more than max_dependencies dependencies, otherwise SOME of the list. *)
fun trim_dependencies deps = if length deps > max_dependencies then NONE else SOME deps
(* Dependencies found in the proof of "th" (via thms_in_proof, limited by max_dependencies).
   Dependency lists that merely reflect typedef, class, function-package or size machinery are
   replaced by the empty list. *)
fun isar_dependencies_of name_tabs th = thms_in_proof max_dependencies (SOME name_tabs) th |> Option.map (fn deps => if deps = [typedef_dep] orelse deps = [class_some_dep] orelse exists (member (op =) fundef_ths) deps orelse exists (member (op =) typedef_ths) deps orelse is_size_def deps th then [] else deps)
(* Tries to re-prove "th" with an automatic prover and reports the facts it used.
   Returns (true, used fact names) when the prover succeeds and (false, Isar dependencies)
   otherwise; theorems whose Isar dependencies are SOME [] yield (false, []) without a prover run.
   The prover is given the facts that are strictly smaller than "th" w.r.t. thm_less, filtered
   by mepo_suggested_facts, with the Isar dependencies appended if not already present. *)
fun prover_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover auto_level facts name_tabs th = (case isar_dependencies_of name_tabs th of SOME [] => (false, []) | isar_deps0 => let val isar_deps = these isar_deps0 val thy = Proof_Context.theory_of ctxt val goal = goal_of_thm thy th val name = nickname_of_thm th val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal 1 ctxt val facts = facts |> filter (fn (_, th') => thm_less (th', th)) fun nickify ((_, stature), th) = ((nickname_of_thm th, stature), th) fun is_dep dep (_, th) = (nickname_of_thm th = dep) fun add_isar_dep facts dep accum = if exists (is_dep dep) accum then accum else (case find_first (is_dep dep) facts of SOME ((_, status), th) => accum @ [(("", status), th)] | NONE => accum (* should not happen *)) val mepo_facts = facts |> mepo_suggested_facts ctxt params (max_facts |> the_default prover_default_max_facts) NONE hyp_ts concl_t val facts = mepo_facts |> fold (add_isar_dep facts) isar_deps |> map nickify val num_isar_deps = length isar_deps in if 
verbose andalso auto_level = 0 then writeln ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts") else (); (case run_prover_for_mash ctxt params prover name facts goal of {outcome = NONE, used_facts, ...} => (if verbose andalso auto_level = 0 then let val num_facts = length used_facts in writeln ("Found proof with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts) end else (); (true, map fst used_facts)) | _ => (false, isar_deps)) end)
(*** High-level communication with MaSh ***)
(* In the following functions, chunks are risers w.r.t. "thm_less_eq". *)
(* Recomputes the chunks and the parents of theorem "th".
   - insert_parent drops every parent that is thm_less_eq the new one and adds the new one unless
     it is itself thm_less_eq some remaining parent.
   - do_chunk: a chunk whose head is thm_less_eq th is kept and its head becomes a parent
     candidate; a chunk of which only the last element is thm_less_eq th is split by rechunk at
     the first element that is thm_less_eq th, which becomes the parent candidate; other chunks
     are kept unchanged.
   - rechunk has no clause for the empty list; it is only called when the last element of the
     chunk satisfies the test, so it stops before running out of elements.
   The result is the new chunk list with a fresh empty chunk in front, paired with the nicknames
   of the parents. *)
fun chunks_and_parents_for chunks th = let fun insert_parent new parents = let val parents = parents |> filter_out (fn p => thm_less_eq (p, new)) in parents |> forall (fn p => not (thm_less_eq (new, p))) parents ? cons new end fun rechunk seen (rest as th' :: ths) = if thm_less_eq (th', th) then (rev seen, rest) else rechunk (th' :: seen) ths fun do_chunk [] accum = accum | do_chunk (chunk as hd_chunk :: _) (chunks, parents) = if thm_less_eq (hd_chunk, th) then (chunk :: chunks, insert_parent hd_chunk parents) else if thm_less_eq (List.last chunk, th) then let val (front, back as hd_back :: _) = rechunk [] chunk in (front :: back :: chunks, insert_parent hd_back parents) end else (chunk :: chunks, parents) in fold_rev do_chunk chunks ([], []) |>> cons [] ||> map nickname_of_thm end
(* Pairs every fact of "facts" with the nicknames of its parents. The facts are processed in
   order after "old_facts"; a fact that is thm_less_eq its successor and belongs to the same
   theory becomes the sole parent of that successor, otherwise the parents of the successor are
   computed with chunks_and_parents_for. The entries for "old_facts" are dropped from the result. *)
fun attach_parents_to_facts _ [] = [] | attach_parents_to_facts old_facts (facts as (_, th) :: _) = let fun do_facts _ [] = [] | do_facts (_, parents) [fact] = [(parents, fact)] | do_facts (chunks, parents) ((fact as (_, th)) :: (facts as (_, th') :: _)) = let val chunks = app_hd (cons th) chunks val chunks_and_parents' = if thm_less_eq (th, th') andalso Thm.theory_name th = Thm.theory_name th' then (chunks, [nickname_of_thm th]) else chunks_and_parents_for chunks th' in (parents, fact) :: do_facts chunks_and_parents' facts end in 
old_facts @ facts |> do_facts (chunks_and_parents_for [[]] th) |> drop (length old_facts) end
(* True iff the graph has a node for the nickname of the given theorem. *)
fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
(* Factors applied to the features of chained facts and of extra (nearby) facts. *)
val chained_feature_factor = 0.5 (* FUDGE *)
val extra_feature_factor = 0.1 (* FUDGE *)
(* Upper bound on the number of facts whose features are added to the goal features. *)
val num_extra_feature_facts = 10 (* FUDGE *)
(* Number of leading facts of the sorted fact list that count as proximate. *)
val max_proximity_facts = 100 (* FUDGE *)
(* Combines the learner's suggestions "suggs" (nicknames) with the facts unknown to the learner.
   Three rankings are meshed: unknown chained facts (global weight 0.9, score 1.0 each), unknown
   proximate facts (0.4, smooth weights) and the suggested facts (0.1, steep weights, with all
   unknown facts as their "unks"). Returns the meshed list of at most max_facts facts together
   with the unknown facts that are neither chained nor proximate. *)
fun find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown = let val inter_fact = inter (eq_snd Thm.eq_thm_prop) val raw_mash = find_suggested_facts ctxt facts suggs val proximate = take max_proximity_facts facts val unknown_chained = inter_fact raw_unknown chained val unknown_proximate = inter_fact raw_unknown proximate val mess = [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])), (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])), (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))] val unknown = raw_unknown |> fold (subtract (eq_snd Thm.eq_thm_prop)) [unknown_chained, unknown_proximate] in (mesh_facts (fact_distinct (op aconv)) (eq_snd (gen_eq_thm ctxt)) max_facts mess, unknown) end
(* Entry point for fact suggestion. Sorts the facts with crude_thm_ord (partially, via
   rev_sort_list_prefix), builds weighted goal features from the goal itself (weight 1.0), from
   the chained facts and from a few leading facts of the goal's theory, queries the external or
   internal learner depending on the_mash_algorithm, and post-processes the suggestions with
   find_mash_suggestions. Yields ([], []) when get_state returns NONE. *)
fun mash_suggested_facts ctxt thy_name ({debug, ...} : params) max_suggs hyp_ts concl_t facts = let val algorithm = the_mash_algorithm () val facts = facts |> rev_sort_list_prefix (crude_thm_ord ctxt o apply2 snd) (Int.max (num_extra_feature_facts, max_proximity_facts)) val chained = filter (fn ((_, (scope, _)), _) => scope = Chained) facts fun fact_has_right_theory (_, th) = thy_name = Thm.theory_name th fun chained_or_extra_features_of factor (((_, stature), th), weight) = [Thm.prop_of th] |> features_of ctxt (Thm.theory_name th) stature |> map (rpair (weight * factor)) in (case get_state ctxt of NONE => ([], []) | SOME {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} => let val goal_feats0 = features_of ctxt thy_name (Local, General) (concl_t :: hyp_ts) val chained_feats = chained |> map (rpair 1.0) |> map (chained_or_extra_features_of 
chained_feature_factor) |> rpair [] |-> fold (union (eq_fst (op =))) val extra_feats = facts |> take (Int.max (0, num_extra_feature_facts - length chained)) |> filter fact_has_right_theory |> weight_facts_steeply |> map (chained_or_extra_features_of extra_feature_factor) |> rpair [] |-> fold (union (eq_fst (op =))) val goal_feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0) |> debug ? sort (Real.compare o swap o apply2 snd) val fact_idxs = map_filter (Symtab.lookup fact_tab o nickname_of_thm o snd) facts val suggs = if algorithm = MaSh_NB_Ext orelse algorithm = MaSh_kNN_Ext then let val learns = Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G in MaSh.query_external ctxt algorithm max_suggs learns goal_feats end else let val int_goal_feats = map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats in MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs fact_idxs max_suggs goal_feats int_goal_feats end val unknown = filter_out (is_fact_in_graph access_G o snd) facts in find_mash_suggestions ctxt max_suggs suggs facts chained unknown
  (* NOTE(review): the two fragments below, introduced by "-" and "+", look like the removed and
     the added line of a unified diff (fact_of_raw_fact replaced by fact_of_lazy_fact) that ended
     up in the text; as written they do not form a valid expression. Presumably only one of the
     two belongs here -- resolve against the intended revision. *)
  - |> apply2 (map fact_of_raw_fact)
  + |> apply2 (map fact_of_lazy_fact)
  end) end
(* Clears the learner state and reports it to the user. *)
fun mash_unlearn ctxt = (clear_state ctxt; writeln "Reset MaSh")
(* Enters a newly learned fact into the graph and the symbol tables.
   The node is created with (Isar_Proof, feats, deps) unless it already exists. Edges from the
   parents are added one by one with Graph.add_edge_acyclic under try_graph, so parents whose edge
   cannot be added are dropped. The dependencies are filtered the same way, but the graph built
   while folding over them is discarded: no dependency edges end up in the returned graph.
   Returns SOME (name, accepted parents, feats, accepted deps) with the updated graph and tables,
   or NONE with the unchanged accumulator on Symtab.DUP. *)
fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (accum as (access_G, (fact_xtab, feat_xtab))) = let fun maybe_learn_from from (accum as (parents, access_G)) = try_graph ctxt "updating graph" accum (fn () => (from :: parents, Graph.add_edge_acyclic (from, name) access_G)) val access_G = access_G |> Graph.default_node (name, (Isar_Proof, feats, deps)) val (parents, access_G) = ([], access_G) |> fold maybe_learn_from parents val (deps, _) = ([], access_G) |> fold maybe_learn_from deps val fact_xtab = add_to_xtab name fact_xtab val feat_xtab = fold maybe_add_to_xtab feats feat_xtab in (SOME (name, parents, feats, deps), (access_G, (fact_xtab, feat_xtab))) end handle 
Symtab.DUP _ => (NONE, accum) (* facts sometimes have the same name, confusingly *) fun relearn_wrt_access_graph ctxt (name, deps) access_G = let fun maybe_relearn_from from (accum as (parents, access_G)) = try_graph ctxt "updating graph" accum (fn () => (from :: parents, Graph.add_edge_acyclic (from, name) access_G)) val access_G = access_G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps)) val (deps, _) = ([], access_G) |> fold maybe_relearn_from deps in ((name, deps), access_G) end fun flop_wrt_access_graph name = Graph.map_node name (fn (_, feats, deps) => (Isar_Proof_wegen_Prover_Flop, feats, deps)) val learn_timeout_slack = 20.0 fun launch_thread timeout task = let val hard_timeout = Time.scale learn_timeout_slack timeout val birth_time = Time.now () val death_time = birth_time + Timeout.scale_time hard_timeout val desc = ("Machine learner for Sledgehammer", "") in Async_Manager_Legacy.thread MaShN birth_time death_time desc task end fun anonymous_proof_name () = Date.fmt (anonymous_proof_prefix ^ "%Y%m%d.%H%M%S.") (Date.fromTimeLocal (Time.now ())) ^ serial_string () fun mash_learn_proof ctxt ({timeout, ...} : params) t used_ths = if not (null used_ths) andalso is_mash_enabled () then launch_thread timeout (fn () => let val thy = Proof_Context.theory_of ctxt val feats = features_of ctxt (Context.theory_name thy) (Local, General) [t] in map_state ctxt (fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} => let val deps = used_ths |> filter (is_fact_in_graph access_G) |> map nickname_of_thm val name = anonymous_proof_name () val (access_G', xtabs', rev_learns) = add_node Automatic_Proof name [] (* ignore parents *) feats deps (access_G, xtabs, []) val (ffds', freqs') = recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs in {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs', dirty_facts = Option.map (cons name) dirty_facts} end); (true, "") end) else () fun sendback sub = 
Active.sendback_markup_command (sledgehammerN ^ " " ^ sub) val commit_timeout = seconds 30.0 (* The timeout is understood in a very relaxed fashion. *) fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover learn_timeout facts = let val timer = Timer.startRealTimer () fun next_commit_time () = Timer.checkRealTimer timer + commit_timeout in (case get_state ctxt of NONE => "MaSh is busy\nPlease try again later" | SOME {access_G, ...} => let val is_in_access_G = is_fact_in_graph access_G o snd val no_new_facts = forall is_in_access_G facts in if no_new_facts andalso not run_prover then if auto_level < 2 then "No new " ^ (if run_prover then "automatic" else "Isar") ^ " proofs to learn" ^ (if auto_level = 0 andalso not run_prover then "\n\nHint: Try " ^ sendback learn_proverN ^ " to learn from an automatic prover" else "") else "" else let val name_tabs = build_name_tables nickname_of_thm facts fun deps_of status th = if status = Non_Rec_Def orelse status = Rec_Def then SOME [] else if run_prover then prover_dependencies_of ctxt params prover auto_level facts name_tabs th |> (fn (false, _) => NONE | (true, deps) => trim_dependencies deps) else isar_dependencies_of name_tabs th fun do_commit [] [] [] state = state | do_commit learns relearns flops {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} = let val was_empty = Graph.is_empty access_G val (learns, (access_G', xtabs')) = fold_map (learn_wrt_access_graph ctxt) learns (access_G, xtabs) |>> map_filter I val (relearns, access_G'') = fold_map (relearn_wrt_access_graph ctxt) relearns access_G' val access_G''' = access_G'' |> fold flop_wrt_access_graph flops val dirty_facts' = (case (was_empty, dirty_facts) of (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names) | _ => NONE) val (ffds', freqs') = if null relearns then recompute_ffds_freqs_from_learns (map (fn (name, _, feats, deps) => (name, feats, deps)) learns) xtabs' num_facts0 ffds freqs else 
recompute_ffds_freqs_from_access_G access_G''' xtabs' in {access_G = access_G''', xtabs = xtabs', ffds = ffds', freqs = freqs', dirty_facts = dirty_facts'} end fun commit last learns relearns flops = (if debug andalso auto_level = 0 then writeln "Committing..." else (); map_state ctxt (do_commit (rev learns) relearns flops); if not last andalso auto_level = 0 then let val num_proofs = length learns + length relearns in writeln ("Learned " ^ string_of_int num_proofs ^ " " ^ (if run_prover then "automatic" else "Isar") ^ " proof" ^ plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout) end else ()) fun learn_new_fact _ (accum as (_, (_, _, true))) = accum | learn_new_fact (parents, ((_, stature as (_, status)), th)) (learns, (num_nontrivial, next_commit, _)) = let val name = nickname_of_thm th val feats = features_of ctxt (Thm.theory_name th) stature [Thm.prop_of th] val deps = these (deps_of status th) val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1 val learns = (name, parents, feats, deps) :: learns val (learns, next_commit) = if Timer.checkRealTimer timer > next_commit then (commit false learns [] []; ([], next_commit_time ())) else (learns, next_commit) val timed_out = Timer.checkRealTimer timer > learn_timeout in (learns, (num_nontrivial, next_commit, timed_out)) end val (num_new_facts, num_nontrivial) = if no_new_facts then (0, 0) else let val new_facts = facts |> sort (crude_thm_ord ctxt o apply2 snd) |> map (pair []) (* ignore parents *) |> filter_out (is_in_access_G o snd) val (learns, (num_nontrivial, _, _)) = ([], (0, next_commit_time (), false)) |> fold learn_new_fact new_facts in commit true learns [] []; (length new_facts, num_nontrivial) end fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum | relearn_old_fact ((_, (_, status)), th) ((relearns, flops), (num_nontrivial, next_commit, _)) = let val name = nickname_of_thm th val (num_nontrivial, relearns, flops) = (case deps_of status th of SOME deps => 
(num_nontrivial + 1, (name, deps) :: relearns, flops) | NONE => (num_nontrivial, relearns, name :: flops)) val (relearns, flops, next_commit) = if Timer.checkRealTimer timer > next_commit then (commit false [] relearns flops; ([], [], next_commit_time ())) else (relearns, flops, next_commit) val timed_out = Timer.checkRealTimer timer > learn_timeout in ((relearns, flops), (num_nontrivial, next_commit, timed_out)) end val num_nontrivial = if not run_prover then num_nontrivial else let val max_isar = 1000 * max_dependencies fun priority_of th = Random.random_range 0 max_isar + (case try (Graph.get_node access_G) (nickname_of_thm th) of SOME (Isar_Proof, _, deps) => ~100 * length deps | SOME (Automatic_Proof, _, _) => 2 * max_isar | SOME (Isar_Proof_wegen_Prover_Flop, _, _) => max_isar | NONE => 0) val old_facts = facts |> filter is_in_access_G |> map (`(priority_of o snd)) |> sort (int_ord o apply2 fst) |> map snd val ((relearns, flops), (num_nontrivial, _, _)) = (([], []), (num_nontrivial, next_commit_time (), false)) |> fold relearn_old_fact old_facts in commit true [] relearns flops; num_nontrivial end in if verbose orelse auto_level < 2 then "Learned " ^ string_of_int num_new_facts ^ " fact" ^ plural_s num_new_facts ^ " and " ^ string_of_int num_nontrivial ^ " nontrivial " ^ (if run_prover then "automatic and " else "") ^ "Isar proof" ^ plural_s num_nontrivial ^ (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer) else "") else "" end end) end fun mash_learn ctxt (params as {provers, timeout, induction_rules, ...}) fact_override chained run_prover = let val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt val facts = nearly_all_facts ctxt (induction_rules = SOME Instantiate) fact_override Keyword.empty_keywords css chained [] \<^prop>\True\ |> sort (crude_thm_ord ctxt o apply2 snd o swap) val num_facts = length facts val prover = hd provers fun learn auto_level run_prover = mash_learn_facts ctxt params prover auto_level run_prover one_year 
facts |> writeln in if run_prover then (writeln ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^ ").\n\nCollecting Isar proofs first..."); learn 1 false; writeln "Now collecting automatic proofs\n\ \This may take several hours; you can safely stop the learning process at any point"; learn 0 true) else (writeln ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^ " for Isar proofs..."); learn 0 false) end fun mash_can_suggest_facts ctxt = (case get_state ctxt of NONE => false | SOME {access_G, ...} => not (Graph.is_empty access_G)) fun mash_can_suggest_facts_fast ctxt = (case peek_state ctxt of NONE => false | SOME (_, {access_G, ...}) => not (Graph.is_empty access_G)) (* Generate more suggestions than requested, because some might be thrown out later for various reasons (e.g., duplicates). *) fun generous_max_suggestions max_facts = 2 * max_facts + 25 (* FUDGE *) val mepo_weight = 0.5 (* FUDGE *) val mash_weight = 0.5 (* FUDGE *) val max_facts_to_learn_before_query = 100 (* FUDGE *) (* The threshold should be large enough so that MaSh does not get activated for Auto Sledgehammer. 
*) val min_secs_for_learning = 10 fun relevant_facts ctxt (params as {verbose, learn, fact_filter, timeout, ...}) prover max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts = if not (subset (op =) (the_list fact_filter, fact_filters)) then error ("Unknown fact filter: " ^ quote (the fact_filter)) else if only then - [("", map fact_of_raw_fact facts)] + [("", map fact_of_lazy_fact facts)] else if max_facts <= 0 orelse null facts then [("", [])] else let val thy_name = Context.theory_name (Proof_Context.theory_of ctxt) fun maybe_launch_thread exact min_num_facts_to_learn = if not (Async_Manager_Legacy.has_running_threads MaShN) andalso Time.toSeconds timeout >= min_secs_for_learning then let val timeout = Time.scale learn_timeout_slack timeout in (if verbose then writeln ("Started MaShing through " ^ (if exact then "" else "up to ") ^ string_of_int min_num_facts_to_learn ^ " fact" ^ plural_s min_num_facts_to_learn ^ " in the background") else ()); launch_thread timeout (fn () => (true, mash_learn_facts ctxt params prover 2 false timeout facts)) end else () val mash_enabled = is_mash_enabled () val mash_fast = mash_can_suggest_facts_fast ctxt fun please_learn () = if mash_fast then (case get_state ctxt of NONE => maybe_launch_thread false (length facts) | SOME {access_G, xtabs = ((num_facts0, _), _), ...} => let val is_in_access_G = is_fact_in_graph access_G o snd val min_num_facts_to_learn = length facts - num_facts0 in if min_num_facts_to_learn <= max_facts_to_learn_before_query then (case length (filter_out is_in_access_G facts) of 0 => () | num_facts_to_learn => if num_facts_to_learn <= max_facts_to_learn_before_query then mash_learn_facts ctxt params prover 2 false timeout facts |> (fn "" => () | s => writeln (MaShN ^ ": " ^ s)) else maybe_launch_thread true num_facts_to_learn) else maybe_launch_thread false min_num_facts_to_learn end) else maybe_launch_thread false (length facts) val _ = if learn andalso mash_enabled andalso fact_filter <> SOME 
mepoN then please_learn () else () val effective_fact_filter = (case fact_filter of SOME ff => ff | NONE => if mash_enabled andalso mash_fast then meshN else mepoN) val unique_facts = drop_duplicate_facts facts val add_ths = Attrib.eval_thms ctxt add fun in_add (_, th) = member Thm.eq_thm_prop add_ths th fun add_and_take accepts = (case add_ths of [] => accepts | _ => - (unique_facts |> filter in_add |> map fact_of_raw_fact) @ (accepts |> filter_out in_add)) + (unique_facts |> filter in_add |> map fact_of_lazy_fact) + @ (accepts |> filter_out in_add)) |> take max_facts fun mepo () = (mepo_suggested_facts ctxt params max_facts NONE hyp_ts concl_t unique_facts |> weight_facts_steeply, []) fun mash () = mash_suggested_facts ctxt thy_name params (generous_max_suggestions max_facts) hyp_ts concl_t facts |>> weight_facts_steeply val mess = (* the order is important for the "case" expression below *) [] |> effective_fact_filter <> mepoN ? cons (mash_weight, mash) |> effective_fact_filter <> mashN ? cons (mepo_weight, mepo) |> Par_List.map (apsnd (fn f => f ())) val mesh = mesh_facts (fact_distinct (op aconv)) (eq_snd (gen_eq_thm ctxt)) max_facts mess |> add_and_take in (case (fact_filter, mess) of (NONE, [(_, (mepo, _)), (_, (mash, _))]) => [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), (mashN, mash |> map fst |> add_and_take)] | _ => [(effective_fact_filter, mesh)]) end end; diff --git a/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML b/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mepo.ML @@ -1,553 +1,553 @@ (* Title: HOL/Tools/Sledgehammer/sledgehammer_mepo.ML Author: Jia Meng, Cambridge University Computer Laboratory and NICTA Author: Jasmin Blanchette, TU Muenchen Sledgehammer's iterative relevance filter (MePo = Meng-Paulson). 
*)

(* Public interface of the MePo relevance filter: the record of tuning factors, the trace
   option, the name of the pseudo-constant standing for abstractions, and the entry point
   "mepo_suggested_facts". *)
signature SLEDGEHAMMER_MEPO =
sig
  type stature = ATP_Problem_Generate.stature
-  type raw_fact = Sledgehammer_Fact.raw_fact
+  type lazy_fact = Sledgehammer_Fact.lazy_fact
  type fact = Sledgehammer_Fact.fact
  type params = Sledgehammer_Prover.params

  type relevance_fudge =
    {local_const_multiplier : real,
     worse_irrel_freq : real,
     higher_order_irrel_weight : real,
     abs_rel_weight : real,
     abs_irrel_weight : real,
     theory_const_rel_weight : real,
     theory_const_irrel_weight : real,
     chained_const_irrel_weight : real,
     intro_bonus : real,
     elim_bonus : real,
     simp_bonus : real,
     local_bonus : real,
     assum_bonus : real,
     chained_bonus : real,
     max_imperfect : real,
     max_imperfect_exp : real,
     threshold_divisor : real,
     ridiculous_threshold : real}

  val trace : bool Config.T
  val pseudo_abs_name : string
  val default_relevance_fudge : relevance_fudge
  val mepo_suggested_facts : Proof.context -> params -> int -> relevance_fudge option ->
-    term list -> term -> raw_fact list -> fact list
+    term list -> term -> lazy_fact list -> fact list
end;

structure Sledgehammer_MePo : SLEDGEHAMMER_MEPO =
struct

open ATP_Problem_Generate
open Sledgehammer_Util
open Sledgehammer_Fact
open Sledgehammer_Prover

(* Boolean configuration option "sledgehammer_mepo_trace"; its default is given by "K false".
   The antiquotation argument must be a cartouche, i.e. enclosed in the open/close symbols. *)
val trace = Attrib.setup_config_bool \<^binding>\<open>sledgehammer_mepo_trace\<close> (K false)

(* Emits "msg ()" through "tracing" when the trace option is set in "ctxt". The message is
   passed as a thunk, so it is only computed when tracing is actually enabled. *)
fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()

val sledgehammer_prefix = "Sledgehammer" ^ Long_Name.separator

(* Name of the pseudo-constant that is recorded in place of lambda-abstractions. *)
val pseudo_abs_name = sledgehammer_prefix ^ "abs"

(* Suffix appended to a theory name to build the dummy constant that represents the theory. *)
val theory_const_suffix = Long_Name.separator ^ " 1"

(* Same record type as declared in the signature above. *)
type relevance_fudge =
  {local_const_multiplier : real,
   worse_irrel_freq : real,
   higher_order_irrel_weight : real,
   abs_rel_weight : real,
   abs_irrel_weight : real,
   theory_const_rel_weight : real,
   theory_const_irrel_weight : real,
   chained_const_irrel_weight : real,
   intro_bonus : real,
   elim_bonus : real,
   simp_bonus : real,
   local_bonus : real,
   assum_bonus : real,
   chained_bonus : real,
   max_imperfect : real,
   max_imperfect_exp : real,
   threshold_divisor : real,
   ridiculous_threshold : real}
(* FUDGE *)
(* Default values of all tuning factors of the relevance filter. *)
val default_relevance_fudge =
  {local_const_multiplier = 1.5,
   worse_irrel_freq = 100.0,
   higher_order_irrel_weight = 1.05,
   abs_rel_weight = 0.5,
   abs_irrel_weight = 2.0,
   theory_const_rel_weight = 0.5,
   theory_const_irrel_weight = 0.25,
   chained_const_irrel_weight = 0.25,
   intro_bonus = 0.15,
   elim_bonus = 0.15,
   simp_bonus = 0.15,
   local_bonus = 0.55,
   assum_bonus = 1.05,
   chained_bonus = 1.5,
   max_imperfect = 11.5,
   max_imperfect_exp = 1.0,
   threshold_divisor = 2.0,
   ridiculous_threshold = 0.1}

(* Order of a type: for a function type, the maximum of the argument's order plus one and the
   result's order; for any other type constructor, the maximum order of its arguments (0 if it
   has none); for type variables, 0. The antiquotation argument must be a cartouche. *)
fun order_of_type (Type (\<^type_name>\<open>fun\<close>, [T1, T2])) =
    Int.max (order_of_type T1 + 1, order_of_type T2)
  | order_of_type (Type (_, Ts)) = fold (Integer.max o order_of_type) Ts 0
  | order_of_type _ = 0

(* An abstraction of Isabelle types and first-order terms *)
datatype pattern = PVar | PApp of string * pattern list
(* A pattern type pairs an order (an int) with a list of types. *)
datatype ptype = PType of int * typ list

(* Printing of types used as patterns: schematic type variables are shown as "_", free type
   variables and nullary type constructors by name, and other type constructors as the name
   followed by the parenthesized, comma-separated arguments. *)
fun string_of_patternT (TVar _) = "_"
  | string_of_patternT (Type (s, ps)) =
    if null ps then s else s ^ string_of_patternsT ps
  | string_of_patternT (TFree (s, _)) = s
and string_of_patternsT ps = "(" ^ commas (map string_of_patternT ps) ^ ")"

(* Prints only the type list of a pattern type; the order component is ignored. *)
fun string_of_ptype (PType (_, ps)) = string_of_patternsT ps

(*Is the second type an instance of the first one?*)
(* A schematic type variable on the left matches anything; type constructors must agree on the
   name and match argumentwise; free type variables must have the same name (sorts are not
   compared). *)
fun match_patternT (TVar _, _) = true
  | match_patternT (Type (s, ps), Type (t, qs)) =
    s = t andalso match_patternsT (ps, qs)
  | match_patternT (TFree (s, _), TFree (t, _)) = s = t
  | match_patternT (_, _) = false
(* List version. Note the asymmetry: as soon as the second list is exhausted the match
   succeeds, even if patterns remain in the first list; the converse situation fails. *)
and match_patternsT (_, []) = true
  | match_patternsT ([], _) = false
  | match_patternsT (p :: ps, q :: qs) =
    match_patternT (p, q) andalso match_patternsT (ps, qs)

(* Matches the type lists of two pattern types, ignoring their orders. *)
fun match_ptype (PType (_, ps), PType (_, qs)) = match_patternsT (ps, qs)

(* Is there a unifiable constant? *)
(* "consts" is an association list from constant names to pattern types. Every entry stored
   under "s" is paired with "ps"; the pair is passed through "f" (callers in this file use "I"
   or "swap" to pick the matching direction) and then tested with "match_ptype". *)
fun pconst_mem f consts (s, ps) =
  exists (curry (match_ptype o f) ps)
    (map snd (filter (curry (op =) s o fst) consts))

(* Same test as "pconst_mem", but against a symbol table mapping each constant name to a list
   of pattern types; a name that is absent from the table yields false. *)
fun pconst_hyper_mem f const_tab (s, ps) =
  exists (curry (match_ptype o f) ps) (these (Symtab.lookup const_tab s))

(* Pairs a constant with the list of its type instantiations.
*)
(* For a proper constant ("const" is true) this returns the type arguments computed by
   "Sign.const_typargs", or [] if that call raises an exception; for anything else it
   returns []. *)
fun ptype thy const x =
  (if const then these (try (Sign.const_typargs thy) x) else [])

(* Pattern type of "(s, T)": the order of "T" together with the type arguments of the constant. *)
fun rich_ptype thy const (s, T) =
  PType (order_of_type T, ptype thy const (s, T))

(* Pairs the name "s" with its pattern type. *)
fun rich_pconst thy const (s, T) = (s, rich_ptype thy const (s, T))

(* Prints a constant name followed by all of its pattern types, in curly braces. *)
fun string_of_hyper_pconst (s, ps) =
  s ^ "{" ^ commas (map string_of_ptype ps) ^ "}"

(* Structural equality on types used as patterns. All schematic type variables are considered
   equal to each other, and free type variables are compared by name only. *)
fun patternT_eq (TVar _, TVar _) = true
  | patternT_eq (Type (s, Ts), Type (t, Us)) = s = t andalso patternsT_eq (Ts, Us)
  | patternT_eq (TFree (s, _), TFree (t, _)) = (s = t)
  | patternT_eq _ = false
(* List version: both lists must have the same length and be pairwise equal. *)
and patternsT_eq ([], []) = true
  | patternsT_eq ([], _) = false
  | patternsT_eq (_, []) = false
  | patternsT_eq (T :: Ts, U :: Us) = patternT_eq (T, U) andalso patternsT_eq (Ts, Us)

(* Two pattern types are equal if their orders and their type lists are equal. *)
fun ptype_eq (PType (m, Ts), PType (n, Us)) = m = n andalso patternsT_eq (Ts, Us)

(* Add a pconstant to the table, but a [] entry means a standard connective,
   which we ignore. *)
(* NOTE(review): the function itself has no special case for []; it simply inserts "p" into
   the list stored under "s" (starting from "[p]" if there is no entry yet), using "ptype_eq"
   to avoid duplicates. Confirm whether the remark about [] entries is still accurate. *)
fun add_pconst_to_table (s, p) = Symtab.map_default (s, [p]) (insert ptype_eq p)

(* Set constants tend to pull in too many irrelevant facts. We limit the damage
   by treating them more or less as if they were built-in but add their
   axiomatization at the end. *)
(* The antiquotation arguments below must be cartouches. *)
val set_consts = [\<^const_name>\<open>Collect\<close>, \<^const_name>\<open>Set.member\<close>]
val set_thms = @{thms Collect_mem_eq mem_Collect_eq Collect_cong}

(* Returns a function that adds the pattern constants occurring in a term to a symbol table.
   Three mutually recursive helpers do the traversal:
   - "do_const" records a constant or free variable (unless it is a set constant or
     "is_irrelevant_const" holds for its name) and then visits its arguments;
   - "do_term" handles general terms; "ext_arg" is true only for the right-hand side of an
     equality, in which case an unapplied abstraction is not recorded;
   - "do_formula" descends through the logical connectives and quantifiers listed below and
     hands everything else to "do_term". *)
fun add_pconsts_in_term thy =
  let
    fun do_const const (x as (s, _)) ts =
      if member (op =) set_consts s then
        fold (do_term false) ts
      else
        (not (is_irrelevant_const s) ? add_pconst_to_table (rich_pconst thy const x))
        #> fold (do_term false) ts
    and do_term ext_arg t =
      (case strip_comb t of
        (Const x, ts) => do_const true x ts
      | (Free x, ts) => do_const false x ts
      | (Abs (_, T, t'), ts) =>
        ((null ts andalso not ext_arg)
         (* Since lambdas on the right-hand side of equalities are usually
            extensionalized later by "abs_extensionalize_term", we don't
            penalize them here. *)
         ? add_pconst_to_table (pseudo_abs_name, PType (order_of_type T + 1, [])))
        #> fold (do_term false) (t' :: ts)
      | (_, ts) => fold (do_term false) ts)
    (* Arguments of type bool are treated as formulas, all others as terms. *)
    and do_term_or_formula ext_arg T =
      if T = HOLogic.boolT then do_formula else do_term ext_arg
    and do_formula t =
      (case t of
        Const (\<^const_name>\<open>Pure.all\<close>, _) $ Abs (_, _, t') => do_formula t'
      | \<^const>\<open>Pure.imp\<close> $ t1 $ t2 => do_formula t1 #> do_formula t2
      | Const (\<^const_name>\<open>Pure.eq\<close>, Type (_, [T, _])) $ t1 $ t2 =>
        do_term_or_formula false T t1 #> do_term_or_formula true T t2
      | \<^const>\<open>Trueprop\<close> $ t1 => do_formula t1
      | \<^const>\<open>False\<close> => I
      | \<^const>\<open>True\<close> => I
      | \<^const>\<open>Not\<close> $ t1 => do_formula t1
      | Const (\<^const_name>\<open>All\<close>, _) $ Abs (_, _, t') => do_formula t'
      | Const (\<^const_name>\<open>Ex\<close>, _) $ Abs (_, _, t') => do_formula t'
      | \<^const>\<open>HOL.conj\<close> $ t1 $ t2 => do_formula t1 #> do_formula t2
      | \<^const>\<open>HOL.disj\<close> $ t1 $ t2 => do_formula t1 #> do_formula t2
      | \<^const>\<open>HOL.implies\<close> $ t1 $ t2 => do_formula t1 #> do_formula t2
      | Const (\<^const_name>\<open>HOL.eq\<close>, Type (_, [T, _])) $ t1 $ t2 =>
        do_term_or_formula false T t1 #> do_term_or_formula true T t2
      | Const (\<^const_name>\<open>If\<close>, Type (_, [_, Type (_, [T, _])])) $ t1 $ t2 $ t3 =>
        do_formula t1 #> fold (do_term_or_formula false T) [t2, t3]
      | Const (\<^const_name>\<open>Ex1\<close>, _) $ Abs (_, _, t') => do_formula t'
      | Const (\<^const_name>\<open>Ball\<close>, _) $ t1 $ Abs (_, _, t') =>
        do_formula (t1 $ Bound ~1) #> do_formula t'
      | Const (\<^const_name>\<open>Bex\<close>, _) $ t1 $ Abs (_, _, t') =>
        do_formula (t1 $ Bound ~1) #> do_formula t'
      | (t0 as Const (_, \<^typ>\<open>bool\<close>)) $ t1 =>
        do_term false t0 #> do_formula t1 (* theory constant *)
      | _ => do_term false t)
  in do_formula end

(* Collects the pattern constants of the proposition "t" and flattens the resulting table
   into a list of (name, pattern type) pairs. *)
fun pconsts_in_fact thy t =
  Symtab.fold (fn (s, pss) => fold (cons o pair s) pss)
    (Symtab.empty |> add_pconsts_in_term thy t) []

(* Inserts a dummy "constant" referring to the theory name, so that relevance
   takes the given theory into account.
*)
(* If at least one of the two theory-constant weights is positive, wraps "t" in an application
   of a dummy boolean constant named after the theory; otherwise returns "t" unchanged. The
   antiquotation argument must be a cartouche. *)
fun theory_constify ({theory_const_rel_weight, theory_const_irrel_weight, ...}
                     : relevance_fudge) thy_name t =
  if exists (curry (op <) 0.0) [theory_const_rel_weight, theory_const_irrel_weight] then
    Const (thy_name ^ theory_const_suffix, \<^typ>\<open>bool\<close>) $ t
  else
    t

(* Proposition of the theorem "th", tagged with the name of the theory of "th". *)
fun theory_const_prop_of fudge th =
  theory_constify fudge (Thm.theory_name th) (Thm.prop_of th)

(* Pairs a fact with the pattern constants of its (theory-tagged) proposition and with an
   initially absent weight (NONE). Facts without any such constants are dropped. *)
fun pair_consts_fact thy fudge fact =
  (case fact |> snd |> theory_const_prop_of fudge |> pconsts_in_fact thy of
    [] => NONE
  | consts => SOME ((fact, consts), NONE))

(* A two-dimensional symbol table counts frequencies of constants. It's keyed
   first by constant name and second by its list of type instantiations. For the
   latter, we need a linear ordering on "pattern list". *)
(* The ordering places schematic type variables first, then type constructors, then free type
   variables. Type constructors are compared by name and then argumentwise; free type
   variables are compared by name only. *)
fun patternT_ord p =
  (case p of
    (Type (s, ps), Type (t, qs)) =>
    (case fast_string_ord (s, t) of
      EQUAL => dict_ord patternT_ord (ps, qs)
    | ord => ord)
  | (TVar _, TVar _) => EQUAL
  | (TVar _, _) => LESS
  | (Type _, TVar _) => GREATER
  | (Type _, TFree _) => LESS
  | (TFree (s, _), TFree (t, _)) => fast_string_ord (s, t)
  | (TFree _, _) => GREATER)

(* Pattern types are ordered by their type lists first; the order is only a tie-breaker. *)
fun ptype_ord (PType (m, ps), PType (n, qs)) =
  (case dict_ord patternT_ord (ps, qs) of
    EQUAL => int_ord (m, n)
  | ord => ord)

structure PType_Tab = Table(type key = ptype val ord = ptype_ord)

(* Adds the occurrences of all constants and free variables in the (theory-tagged) proposition
   of a fact to the two-dimensional frequency table. *)
fun count_fact_consts thy fudge =
  let
    fun do_const const (s, T) ts =
      (* Two-dimensional table update. Constant maps to types maps to count. *)
      PType_Tab.map_default (rich_ptype thy const (s, T), 0) (Integer.add 1)
      |> Symtab.map_default (s, PType_Tab.empty)
      #> fold do_term ts
    and do_term t =
      (case strip_comb t of
        (Const x, ts) => do_const true x ts
      | (Free x, ts) => do_const false x ts
      | (Abs (_, _, t'), ts) => fold do_term (t' :: ts)
      | (_, ts) => fold do_term ts)
  in do_term o theory_const_prop_of fudge o snd end

(* Integer power of a real by repeated multiplication; a negative exponent divides instead. *)
fun pow_int _ 0 = 1.0
  | pow_int x 1 = x
  | pow_int x n = if n > 0 then x * pow_int x (n - 1) else pow_int x (n + 1) / x

(*The frequency of a constant is the sum of those of all instances of its type.*)
(* "match" decides which table entries count as instances. The lookup uses "the", so the
   constant "c" is expected to be present in "const_tab". *)
fun pconst_freq match const_tab (c, ps) =
  PType_Tab.fold (fn (qs, m) => match (ps, qs) ? Integer.add m)
    (the (Symtab.lookup const_tab c)) 0

(* A surprising number of theorems contain only a few significant constants.
   These include all induction rules and other general theorems. *)

(* "log" seems best in practice. A constant function of one ignores the constant
   frequencies. Rare constants give more points if they are relevant than less
   rare ones. *)
(* The first argument (the order) is ignored. *)
fun rel_weight_for _ freq = 1.0 + 2.0 / Math.ln (Real.fromInt freq + 1.0)

(* Irrelevant constants are treated differently. We associate lower penalties to
   very rare constants and very common ones -- the former because they can't
   lead to the inclusion of too many new facts, and the latter because they are
   so common as to be of little interest. *)
(* "k" is "worse_irrel_freq" rounded up and "x" is "worse_irrel_freq" itself. Below "k" the
   weight grows logarithmically with the frequency; from "k" on it follows "rel_weight_for"
   normalized by its value at "k". The result is scaled by "higher_order_irrel_weight" raised
   to the power "order - 1". *)
fun irrel_weight_for ({worse_irrel_freq, higher_order_irrel_weight, ...}
                      : relevance_fudge) order freq =
  let val (k, x) = worse_irrel_freq |> `Real.ceil in
    (if freq < k then Math.ln (Real.fromInt (freq + 1)) / Math.ln x
     else rel_weight_for order freq / rel_weight_for order k)
    * pow_int higher_order_irrel_weight (order - 1)
  end

(* Names containing a "." get the multiplier 1.0; all others get "local_const_multiplier". *)
fun multiplier_of_const_name local_const_multiplier s =
  if String.isSubstring "." s then 1.0 else local_const_multiplier

(* Computes a constant's weight, as determined by its frequency.
*)
(* Shared by the relevant and the irrelevant case:
   - the pseudo-abstraction constant gets "abs_weight";
   - a constant whose name ends in "theory_const_suffix" gets "theory_const_weight";
   - any other constant gets its name multiplier times "weight_for" applied to its order and
     its frequency in "const_tab"; this is further multiplied by "chained_const_weight" if that
     weight is below 1.0 and the constant occurs in "chained_const_tab". *)
fun generic_pconst_weight local_const_multiplier abs_weight theory_const_weight
    chained_const_weight weight_for f const_tab chained_const_tab
    (c as (s, PType (m, _))) =
  if s = pseudo_abs_name then
    abs_weight
  else if String.isSuffix theory_const_suffix s then
    theory_const_weight
  else
    multiplier_of_const_name local_const_multiplier s
    * weight_for m (pconst_freq (match_ptype o f) const_tab c)
    |> (if chained_const_weight < 1.0 andalso
           pconst_hyper_mem I chained_const_tab c then
          curry (op *) chained_const_weight
        else
          I)

(* Weight of a relevant constant. The chained table passed on is empty, so the chained-constant
   factor (0.0 here) is never applied. *)
fun rel_pconst_weight ({local_const_multiplier, abs_rel_weight,
                        theory_const_rel_weight, ...} : relevance_fudge)
                      const_tab =
  generic_pconst_weight local_const_multiplier abs_rel_weight
                        theory_const_rel_weight 0.0 rel_weight_for I const_tab
                        Symtab.empty

(* Weight of an irrelevant constant; frequencies are looked up with the matching direction
   swapped. *)
fun irrel_pconst_weight (fudge as {local_const_multiplier, abs_irrel_weight,
                                   theory_const_irrel_weight,
                                   chained_const_irrel_weight, ...})
                        const_tab chained_const_tab =
  generic_pconst_weight local_const_multiplier abs_irrel_weight
                        theory_const_irrel_weight chained_const_irrel_weight
                        (irrel_weight_for fudge) swap const_tab chained_const_tab

(* Bonus granted to a fact according to its stature (scope, status). The clauses are tried in
   order, so the status (Intro, Elim, Simp) takes precedence over the scope (Local, Assum,
   Chained); anything else gets 0.0. *)
fun stature_bonus ({intro_bonus, ...} : relevance_fudge) (_, Intro) = intro_bonus
  | stature_bonus {elim_bonus, ...} (_, Elim) = elim_bonus
  | stature_bonus {simp_bonus, ...} (_, Simp) = simp_bonus
  | stature_bonus {local_bonus, ...} (Local, _) = local_bonus
  | stature_bonus {assum_bonus, ...} (Assum, _) = assum_bonus
  | stature_bonus {chained_bonus, ...} (Chained, _) = chained_bonus
  | stature_bonus _ _ = 0.0

(* Holds for the pseudo-abstraction constant and for dummy theory constants. *)
fun is_odd_const_name s =
  s = pseudo_abs_name orelse String.isSuffix theory_const_suffix s

(* Relevance weight of a fact, computed from its constants "fact_consts". The constants are
   split into relevant ones (those found in "rel_const_tab") and irrelevant ones. The result is
   0.0 if there is no relevant constant, if all constants are "odd" ones, or if the quotient
   below is not a finite number. *)
fun fact_weight fudge stature const_tab rel_const_tab chained_const_tab
                fact_consts =
  (case fact_consts |> List.partition (pconst_hyper_mem I rel_const_tab)
                    ||> filter_out (pconst_hyper_mem swap rel_const_tab) of
    ([], _) => 0.0
  | (rel, irrel) =>
    if forall (forall (is_odd_const_name o fst)) [rel, irrel] then
      0.0
    else
      let
        val irrel = irrel |> filter_out (pconst_mem swap rel)
        val rel_weight =
          0.0 |> fold (curry (op +) o rel_pconst_weight fudge const_tab) rel
        (* The stature bonus is subtracted from the total weight of the irrelevant constants. *)
        val irrel_weight =
          ~ (stature_bonus fudge stature)
          |> fold (curry (op +)
                   o irrel_pconst_weight fudge const_tab chained_const_tab) irrel
        val res = rel_weight / (rel_weight + irrel_weight)
      in if Real.isFinite res then res else 0.0 end)

(* Splits the weighted candidates into accepted and rejected ones. Candidates are sorted by
   decreasing weight; those with a weight above 0.99999 count as perfect. All perfect
   candidates plus at most "max_imperfect" imperfect ones are kept, and the result is then cut
   down to "remaining_max" entries. The limit on imperfect candidates is recomputed from the
   fudge factors and the ratio "remaining_max / max_facts". *)
fun take_most_relevant ctxt max_facts remaining_max
        ({max_imperfect, max_imperfect_exp, ...} : relevance_fudge)
-        (candidates : ((raw_fact * (string * ptype) list) * real) list) =
+        (candidates : ((lazy_fact * (string * ptype) list) * real) list) =
  let
    val max_imperfect =
      Real.ceil (Math.pow (max_imperfect,
        Math.pow (Real.fromInt remaining_max / Real.fromInt max_facts, max_imperfect_exp)))
    val (perfect, imperfect) =
      candidates |> sort (Real.compare o swap o apply2 snd)
                 |> chop_prefix (fn (_, w) => w > 0.99999)
    val ((accepts, more_rejects), rejects) =
      chop max_imperfect imperfect |>> append perfect |>> chop remaining_max
  in
    trace_msg ctxt (fn () =>
      "Actually passed (" ^ string_of_int (length accepts) ^ " of " ^
      string_of_int (length candidates) ^ "): " ^
      (accepts |> map (fn ((((name, _), _), _), weight) =>
                          name () ^ " [" ^ Real.toString weight ^ "]")
               |> commas));
    (accepts, more_rejects @ rejects)
  end

(* If "tab" is empty, rebuilds it from the propositions of all facts whose scope is "sc";
   otherwise returns "tab" unchanged. *)
fun if_empty_replace_with_scope thy facts sc tab =
  if Symtab.is_empty tab then
    Symtab.empty
    |> fold (add_pconsts_in_term thy) (map_filter (fn ((_, (sc', _)), th) =>
      if sc' = sc then SOME (Thm.prop_of th) else NONE) facts)
  else
    tab

(* Records, for each constant in the proposition of "th", the number of arguments it is applied
   to. The table becomes NONE as soon as a constant occurs with two different numbers of
   arguments; constants for which "is_widely_irrelevant_const" holds are skipped. *)
fun consider_arities th =
  let
    fun aux _ _ NONE = NONE
      | aux t args (SOME tab) =
        (case t of
          t1 $ t2 => SOME tab |> aux t1 (t2 :: args) |> aux t2 []
        | Const (s, _) =>
          (if is_widely_irrelevant_const s then
             SOME tab
           else
             (case Symtab.lookup tab s of
               NONE => SOME (Symtab.update (s, length args) tab)
             | SOME n => if n = length args then SOME tab else NONE))
        | _ => SOME tab)
  in aux (Thm.prop_of th) [] end

(* FIXME: This is currently only useful for polymorphic type encodings.
*)
(* True if some constant occurs with differing numbers of arguments across the given facts,
   i.e. if folding "consider_arities" over them ends in NONE. *)
fun could_benefit_from_ext facts =
  fold (consider_arities o snd) facts (SOME Symtab.empty) |> is_none

(* High enough so that it isn't wrongly considered as very relevant (e.g., for E
   weights), but low enough so that it is unlikely to be truncated away if few
   facts are included. *)
val special_fact_index = 45 (* FUDGE *)

(* Equality on pairs, built from an equality for each component. *)
fun eq_prod eqx eqy ((x1, y1), (x2, y2)) = eqx (x1, x2) andalso eqy (y1, y2)

(* Iteration number at which hopeless facts with a weight below 0.001 are discarded. *)
val really_hopeless_get_kicked_out_iter = 5 (* FUDGE *)

(* The iterative filter proper. Starting from the constants of the goal, each iteration accepts
   the facts whose weight reaches the current threshold, adds their constants to the table of
   relevant constants, and continues with the remaining facts until "max_facts" facts have been
   accepted or an iteration accepts nothing. Some special facts are inserted at the end. *)
fun relevance_filter ctxt thres0 decay max_facts
        (fudge as {threshold_divisor, ridiculous_threshold, ...}) facts hyp_ts concl_t =
  let
    val thy = Proof_Context.theory_of ctxt
    (* Frequencies of the constants over all facts. *)
    val const_tab = fold (count_fact_consts thy fudge) facts Symtab.empty
    val add_pconsts = add_pconsts_in_term thy
    (* Propositions of the facts whose scope is Chained; for all other facts the pattern match
       fails and "try" yields NONE. *)
    val chained_ts =
      facts |> map_filter (try (fn ((_, (Chained, _)), th) => Thm.prop_of th))
    val chained_const_tab = Symtab.empty |> fold add_pconsts chained_ts
    (* Constants of the hypotheses and the conclusion. If there are none, fall back to the
       chained constants, and then in turn to the constants of the facts with scope Chained,
       Assum and Local, for as long as the table is still empty. *)
    val goal_const_tab =
      Symtab.empty
      |> fold add_pconsts hyp_ts
      |> add_pconsts concl_t
      |> (fn tab => if Symtab.is_empty tab then chained_const_tab else tab)
      |> fold (if_empty_replace_with_scope thy facts) [Chained, Assum, Local]

    (* One iteration. "hopeful" facts carry an optional cached weight and are (re)considered;
       "hopeless" facts carry their last computed weight. *)
    fun iter j remaining_max thres rel_const_tab hopeless hopeful =
      let
        val hopeless =
          hopeless |> j = really_hopeless_get_kicked_out_iter
                      ? filter_out (fn (_, w) => w < 0.001)
        fun relevant [] _ [] =
            (* Nothing has been added this iteration. *)
            if j = 0 andalso thres >= ridiculous_threshold then
              (* First iteration? Try again. *)
              iter 0 max_facts (thres / threshold_divisor) rel_const_tab
                   hopeless hopeful
            else
              []
          | relevant candidates rejects [] =
            let
              val (accepts, more_rejects) =
                take_most_relevant ctxt max_facts remaining_max fudge candidates
              (* Constants of the accepted facts become relevant. *)
              val sps = maps (snd o fst) accepts
              val rel_const_tab' =
                rel_const_tab |> fold add_pconst_to_table sps
              (* A constant is dirty if its table entry has changed in this iteration. *)
              fun is_dirty (s, _) =
                Symtab.lookup rel_const_tab' s <> Symtab.lookup rel_const_tab s
              (* Rejected facts mentioning a dirty constant become hopeful again, with their
                 cached weight invalidated; the others stay hopeless with their old weight.
                 Candidates that were cut off by "take_most_relevant" are always hopeful and
                 keep their weight unless they mention a dirty constant. *)
              val (hopeful_rejects, hopeless_rejects) =
                 (rejects @ hopeless, ([], []))
                 |-> fold (fn (ax as (_, consts), old_weight) =>
                              if exists is_dirty consts then
                                apfst (cons (ax, NONE))
                              else
                                apsnd (cons (ax, old_weight)))
                 |>> append (more_rejects
                             |> map (fn (ax as (_, consts), old_weight) =>
                                        (ax, if exists is_dirty consts then NONE
                                             else SOME old_weight)))
              (* The distance of the threshold from 1.0 is multiplied by "decay" once for each
                 accepted fact. *)
              val thres =
                1.0 - (1.0 - thres)
                      * Math.pow (decay, Real.fromInt (length accepts))
              val remaining_max = remaining_max - length accepts
            in
              trace_msg ctxt (fn () => "New or updated constants: " ^
                  commas (rel_const_tab'
                          |> Symtab.dest
                          |> subtract (eq_prod (op =) (eq_list ptype_eq))
                                      (Symtab.dest rel_const_tab)
                          |> map string_of_hyper_pconst));
              map (fst o fst) accepts @
              (if remaining_max = 0 then
                 []
               else
                 iter (j + 1) remaining_max thres rel_const_tab'
                      hopeless_rejects hopeful_rejects)
            end
          | relevant candidates rejects
                     (((ax as (((_, stature), _), fact_consts)), cached_weight)
                      :: hopeful) =
            let
              (* Use the cached weight if there is one; otherwise compute it afresh. *)
              val weight =
                (case cached_weight of
                  SOME w => w
                | NONE =>
                  fact_weight fudge stature const_tab rel_const_tab chained_const_tab
                    fact_consts)
            in
              if weight >= thres then
                relevant ((ax, weight) :: candidates) rejects hopeful
              else
                relevant candidates ((ax, weight) :: rejects) hopeful
            end
        in
          trace_msg ctxt (fn () =>
              "ITERATION " ^ string_of_int j ^ ": current threshold: " ^
              Real.toString thres ^ ", constants: " ^
              commas (rel_const_tab
                      |> Symtab.dest
                      |> filter (curry (op <>) [] o snd)
                      |> map string_of_hyper_pconst));
          relevant [] [] hopeful
        end
    (* Does the constant named "s" occur in the term "t"? *)
    fun uses_const s t =
      fold_aterms (curry (fn (Const (s', _), false) => s' = s | (_, b) => b)) t
                  false
    (* Does it occur in an accepted fact, in the conclusion, or in a hypothesis? *)
    fun uses_const_anywhere accepts s =
      exists (uses_const s o Thm.prop_of o snd) accepts orelse
      exists (uses_const s) (concl_t :: hyp_ts)
    (* Adds the set theorems if any set constant is used anywhere. *)
    fun add_set_const_thms accepts =
      exists (uses_const_anywhere accepts) set_consts ? append set_thms
    (* Inserts the facts corresponding to the theorems "ths" after the first
       "special_fact_index" accepted facts, removing earlier occurrences and truncating the
       accepted facts so that the total does not exceed "max_facts". *)
    fun insert_into_facts accepts [] = accepts
      | insert_into_facts accepts ths =
        let
          val add = facts |> filter (member Thm.eq_thm_prop ths o snd)
          val (bef, after) =
            accepts |> filter_out (member Thm.eq_thm_prop ths o snd)
                    |> take (max_facts - length add)
                    |> chop special_fact_index
        in bef @ add @ after end
    fun insert_special_facts accepts =
      (* FIXME: get rid of "ext" here once it is treated as a helper *)
      []
      |> could_benefit_from_ext accepts ? cons @{thm ext}
      |> add_set_const_thms accepts
      |> insert_into_facts accepts
  in
    facts |> map_filter (pair_consts_fact thy fudge)
          |> iter 0 max_facts thres0 goal_const_tab []
          |> insert_special_facts
          |> tap (fn accepts => trace_msg ctxt (fn () =>
                      "Total relevant: " ^ string_of_int (length accepts)))
  end

(* Entry point of the filter. "fudge" defaults to "default_relevance_fudge". All facts are
   returned if "thres1" is negative; none are returned if "thres0" exceeds 1.0 or "thres1", or
   if "max_facts" is not positive; otherwise "relevance_filter" is run on the conclusion tagged
   with the name of the current theory. *)
fun mepo_suggested_facts ctxt ({fact_thresholds = (thres0, thres1), ...} : params) max_facts
      fudge hyp_ts concl_t facts =
  let
    val thy = Proof_Context.theory_of ctxt
    val fudge = fudge |> the_default default_relevance_fudge
    val decay = Math.pow ((1.0 - thres1) / (1.0 - thres0),
                          1.0 / Real.fromInt (max_facts + 1))
  in
    trace_msg ctxt (fn () => "Considering " ^ string_of_int (length facts) ^ " facts");
    (if thres1 < 0.0 then
       facts
     else if thres0 > 1.0 orelse thres0 > thres1 orelse max_facts <= 0 then
       []
     else
       relevance_filter ctxt thres0 decay max_facts fudge facts hyp_ts
         (concl_t |> theory_constify fudge (Context.theory_name thy)))
-    |> map fact_of_raw_fact
+    |> map fact_of_lazy_fact
  end

end;