diff --git a/thys/Deep_Learning/DL_Deep_Model_Poly.thy b/thys/Deep_Learning/DL_Deep_Model_Poly.thy --- a/thys/Deep_Learning/DL_Deep_Model_Poly.thy +++ b/thys/Deep_Learning/DL_Deep_Model_Poly.thy @@ -1,365 +1,363 @@ (* Author: Alexander Bentkamp, Universität des Saarlandes *) section \Polynomials representing the Deep Network Model\ theory DL_Deep_Model_Poly imports DL_Deep_Model Polynomials.More_MPoly_Type Jordan_Normal_Form.Determinant begin definition "polyfun N f = (\p. vars p \ N \ (\x. insertion x p = f x))" lemma polyfunI: "(\P. (\p. vars p \ N \ (\x. insertion x p = f x) \ P) \ P) \ polyfun N f" unfolding polyfun_def by metis lemma polyfun_subset: "N\N' \ polyfun N f \ polyfun N' f" unfolding polyfun_def by blast lemma polyfun_const: "polyfun N (\_. c)" proof - have "\x. insertion x (monom 0 c) = c" using insertion_single by (metis insertion_one monom_one mult.commute mult.right_neutral single_zero) then show ?thesis unfolding polyfun_def by (metis (full_types) empty_iff keys_single single_zero subsetI subset_antisym vars_monom_subset) qed lemma polyfun_add: assumes "polyfun N f" "polyfun N g" shows "polyfun N (\x. f x + g x)" proof - obtain p1 p2 where "vars p1 \ N" "\x. insertion x p1 = f x" "vars p2 \ N" "\x. insertion x p2 = g x" using polyfun_def assms by metis then have "vars (p1 + p2) \ N" "\x. insertion x (p1 + p2) = f x + g x" using vars_add using Un_iff subsetCE subsetI apply blast by (simp add: \\x. insertion x p1 = f x\ \\x. insertion x p2 = g x\ insertion_add) then show ?thesis using polyfun_def by blast qed lemma polyfun_mult: assumes "polyfun N f" "polyfun N g" shows "polyfun N (\x. f x * g x)" proof - obtain p1 p2 where "vars p1 \ N" "\x. insertion x p1 = f x" "vars p2 \ N" "\x. insertion x p2 = g x" using polyfun_def assms by metis then have "vars (p1 * p2) \ N" "\x. insertion x (p1 * p2) = f x * g x" using vars_mult using Un_iff subsetCE subsetI apply blast by (simp add: \\x. insertion x p1 = f x\ \\x. insertion x p2 = g x\ insertion_mult) then show ?thesis using polyfun_def by blast qed lemma polyfun_Sum: assumes "finite I" assumes "\i. i\I \ polyfun N (f i)" shows "polyfun N (\x. \i\I. f i x)" using assms apply (induction I rule:finite_induct) apply (simp add: polyfun_const) using comm_monoid_add_class.sum.insert polyfun_add by fastforce lemma polyfun_Prod: assumes "finite I" assumes "\i. i\I \ polyfun N (f i)" shows "polyfun N (\x. \i\I. f i x)" using assms apply (induction I rule:finite_induct) apply (simp add: polyfun_const) using comm_monoid_add_class.sum.insert polyfun_mult by fastforce lemma polyfun_single: assumes "i\N" shows "polyfun N (\x. x i)" proof - have "\f. insertion f (monom (Poly_Mapping.single i 1) 1) = f i" using insertion_single by simp then show ?thesis unfolding polyfun_def using vars_monom_single[of i 1 1] One_nat_def assms singletonD subset_eq by blast qed lemma polyfun_det: assumes "\x. (A x) \ carrier_mat n n" assumes "\x i j. i j polyfun N (\x. (A x) $$ (i,j))" shows "polyfun N (\x. det (A x))" proof - { fix p assume "p\ {p. p permutes {0..x. x < n \ p x < n" using permutes_in_image by auto then have "polyfun N (\x. \i = 0..i x. A x $$ (i, p i)"] assms by simp then have "polyfun N (\x. signof p * (\i = 0..p x. signof p * (\i = 0..f. extract_matrix (\i. f (i + a)) m n $$ (i,j))" unfolding index_extract_matrix[OF assms] apply (rule polyfun_single) using two_digit_le[OF assms] by simp lemma polyfun_mult_mat_vec: assumes "\x. v x \ carrier_vec n" assumes "\j. j polyfun N (\x. v x $ j)" assumes "\x. A x \ carrier_mat m n" assumes "\i j. i j polyfun N (\x. A x $$ (i,j))" assumes "j < m" shows "polyfun N (\x. ((A x) *\<^sub>v (v x)) $ j)" proof - have "\x. j < dim_row (A x)" using `j < m` assms(3) carrier_matD(1) by force have "\x. n = dim_vec (v x)" using assms(1) carrier_vecD by fastforce { fix i assume "i \ {0..i < dim_vec (v x)\) } then have "polyfun N (\x. row (A x) j $ i * v x $ i)" using polyfun_mult assms(4)[OF `j < m`] assms(2) by fastforce } then show ?thesis unfolding index_mult_mat_vec[OF `\x. j < dim_row (A x)`] scalar_prod_def using polyfun_Sum[of "{0..i x. row (A x) j $ i * v x $ i"] finite_atLeastLessThan[of 0 n] `\x. n = dim_vec (v x)` by simp qed (* The variable a has been inserted here to make the induction work:*) lemma polyfun_evaluate_net_plus_a: assumes "map dim_vec inputs = input_sizes m" assumes "valid_net m" assumes "j < output_size m" shows "polyfun {..f. evaluate_net (insert_weights s m (\i. f (i + a))) inputs $ j)" using assms proof (induction m arbitrary:inputs j a) case (Input) then show ?case unfolding insert_weights.simps evaluate_net.simps using polyfun_const by metis next case (Conv x m) then obtain x1 x2 where "x=(x1,x2)" by fastforce show ?case unfolding `x=(x1,x2)` insert_weights.simps evaluate_net.simps drop_map unfolding list_of_vec_index proof (rule polyfun_mult_mat_vec) { fix f have 1:"valid_net' (insert_weights s m (\i. f (i + x1 * x2)))" using `valid_net (Conv x m)` valid_net.simps by (metis convnet.distinct(1) convnet.distinct(5) convnet.inject(2) remove_insert_weights) have 2:"map dim_vec inputs = input_sizes (insert_weights s m (\i. f (i + x1 * x2)))" using input_sizes_remove_weights remove_insert_weights by (simp add: Conv.prems(1)) have "dim_vec (evaluate_net (insert_weights s m (\i. f (i + x1 * x2))) inputs) = output_size m" using output_size_correct[OF 1 2] using remove_insert_weights by auto then show "evaluate_net (insert_weights s m (\i. f (i + x1 * x2))) inputs \ carrier_vec (output_size m)" using carrier_vec_def by (metis (full_types) mem_Collect_eq) } have "map dim_vec inputs = input_sizes m" by (simp add: Conv.prems(1)) have "valid_net m" using Conv.prems(2) valid_net.cases by fastforce show "\j. j < output_size m \ polyfun {..f. evaluate_net (insert_weights s m (\i. f (i + x1 * x2 + a))) inputs $ j)" unfolding vec_of_list_index count_weights.simps using Conv(1)[OF `map dim_vec inputs = input_sizes m` `valid_net m`, of _ "x1 * x2 + a"] unfolding semigroup_add_class.add.assoc ab_semigroup_add_class.add.commute[of "x1 * x2" a] by blast have "output_size m = x2" using Conv.prems(2) \x = (x1, x2)\ valid_net.cases by fastforce show "\f. extract_matrix (\i. f (i + a)) x1 x2 \ carrier_mat x1 (output_size m)" unfolding `output_size m = x2` using dim_extract_matrix using carrier_matI by (metis (no_types, lifting)) show "\i j. i < x1 \ j < output_size m \ polyfun {..f. extract_matrix (\i. f (i + a)) x1 x2 $$ (i, j))" unfolding `output_size m = x2` count_weights.simps using polyfun_extract_matrix[of _ x1 _ x2 a "count_weights s m"] by blast show "j < x1" using Conv.prems(3) \x = (x1, x2)\ by auto qed next case (Pool m1 m2 inputs j a) have A2:"\f. map dim_vec (take (length (input_sizes (insert_weights s m1 (\i. f (i + a))))) inputs) = input_sizes m1" by (metis Pool.prems(1) append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights take_map) have B2:"\f. map dim_vec (drop (length (input_sizes (insert_weights s m1 (\i. f (i + a))))) inputs) = input_sizes m2" using Pool.prems(1) append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights by (metis drop_map) have A3:"valid_net m1" and B3:"valid_net m2" using `valid_net (Pool m1 m2)` valid_net.simps by blast+ have "output_size (Pool m1 m2) = output_size m2" unfolding output_size.simps using `valid_net (Pool m1 m2)` "valid_net.cases" by fastforce then have A4:"j < output_size m1" and B4:"j < output_size m2" using `j < output_size (Pool m1 m2)` by simp_all let ?net1 = "\f. evaluate_net (insert_weights s m1 (\i. f (i + a))) (take (length (input_sizes (insert_weights s m1 (\i. f (i + a))))) inputs)" let ?net2 = "\f. evaluate_net (insert_weights s m2 (if s then \i. f (i + a) else (\i. f (i + count_weights s m1 + a)))) (drop (length (input_sizes (insert_weights s m1 (\i. f (i + a))))) inputs)" have length1: "\f. output_size m1 = dim_vec (?net1 f)" by (metis A2 A3 input_sizes_remove_weights output_size_correct remove_insert_weights) then have jlength1:"\f. j < dim_vec (?net1 f)" using A4 by metis have length2: "\f. output_size m2 = dim_vec (?net2 f)" by (metis B2 B3 input_sizes_remove_weights output_size_correct remove_insert_weights) then have jlength2:"\f. j < dim_vec (?net2 f)" using B4 by metis have cong1:"\xf. (\f. evaluate_net (insert_weights s m1 (\i. f (i + a))) (take (length (input_sizes (insert_weights s m1 (\i. xf (i + a))))) inputs) $ j) = (\f. ?net1 f $ j)" using input_sizes_remove_weights remove_insert_weights by auto have cong2:"\xf. (\f. evaluate_net (insert_weights s m2 (\i. f (i + (a + (if s then 0 else count_weights s m1))))) (drop (length (input_sizes (insert_weights s m1 (\i. xf (i + a))))) inputs) $ j) = (\f. ?net2 f $ j)" unfolding semigroup_add_class.add.assoc[symmetric] ab_semigroup_add_class.add.commute[of a "if s then 0 else count_weights s m1"] using input_sizes_remove_weights remove_insert_weights by auto show ?case unfolding insert_weights.simps evaluate_net.simps count_weights.simps unfolding index_component_mult[OF jlength1 jlength2] apply (rule polyfun_mult) using Pool.IH(1)[OF A2 A3 A4, of a, unfolded cong1] apply (simp add:polyfun_subset[of "{..f. evaluate_net (insert_weights s m f) inputs $ j)" using polyfun_evaluate_net_plus_a[where a=0, OF assms] by simp lemma polyfun_tensors_from_net: assumes "valid_net m" assumes "is \ input_sizes m" assumes "j < output_size m" shows "polyfun {..f. Tensor.lookup (tensors_from_net (insert_weights s m f) $ j) is)" proof - have 1:"\f. valid_net' (insert_weights s m f)" by (simp add: assms(1) remove_insert_weights) have input_sizes:"\f. input_sizes (insert_weights s m f) = input_sizes m" unfolding input_sizes_remove_weights by (simp add: remove_insert_weights) have 2:"\f. is \ input_sizes (insert_weights s m f)" unfolding input_sizes using assms(2) by blast have 3:"\f. j < output_size' (insert_weights s m f)" by (simp add: assms(3) remove_insert_weights) have "\f1 f2. base_input (insert_weights s m f1) is = base_input (insert_weights s m f2) is" unfolding base_input_def by (simp add: input_sizes) then have "\xf. (\f. evaluate_net (insert_weights s m f) (base_input (insert_weights s m xf) is) $ j) = (\f. evaluate_net (insert_weights s m f) (base_input (insert_weights s m f) is) $ j)" by metis then show ?thesis unfolding lookup_tensors_from_net[OF 1 2 3] using polyfun_evaluate_net[OF base_input_length[OF 2, unfolded input_sizes, symmetric] assms(1) assms(3), of s] -(* TODO: Investigate why sledgehammer fails: - using polyfun_evaluate_net[OF base_input_length[OF 2, unfolded input_sizes, symmetric] assms(1) assms(3)] sledgehammer *) by simp qed lemma polyfun_matricize: assumes "\x. dims (T x) = ds" assumes "\is. is \ ds \ polyfun N (\x. Tensor.lookup (T x) is)" assumes "\x. dim_row (matricize I (T x)) = nr" assumes "\x. dim_col (matricize I (T x)) = nc" assumes "i < nr" assumes "j < nc" shows "polyfun N (\x. matricize I (T x) $$ (i,j))" proof - let ?weave = "\ x. (weave I (digit_encode (nths ds I ) i) (digit_encode (nths ds (-I )) j))" have 1:"\x. matricize I (T x) $$ (i,j) = Tensor.lookup (T x) (?weave x)" unfolding matricize_def by (metis (no_types, lifting) assms(1) assms(3) assms(4) assms(5) assms(6) case_prod_conv dim_col_mat(1) dim_row_mat(1) index_mat(1) matricize_def) have "\x. ?weave x \ ds" using valid_index_weave(1) assms(2) digit_encode_valid_index dim_row_mat(1) matricize_def using assms digit_encode_valid_index matricize_def by (metis dim_col_mat(1)) then have "polyfun N (\x. Tensor.lookup (T x) (?weave x))" using assms(2) by simp then show ?thesis unfolding 1 using assms(1) by blast qed lemma "(\ (a::nat) < b) = (a \ b)" by (metis not_le) lemma polyfun_submatrix: assumes "\x. (A x) \ carrier_mat m n" assumes "\x i j. i j polyfun N (\x. (A x) $$ (i,j))" assumes "i < card {i. i < m \ i \ I}" assumes "j < card {j. j < n \ j \ J}" assumes "infinite I" "infinite J" shows "polyfun N (\x. (submatrix (A x) I J) $$ (i,j))" proof - have 1:"\x. (submatrix (A x) I J) $$ (i,j) = (A x) $$ (pick I i, pick J j)" using submatrix_index by (metis (no_types, lifting) Collect_cong assms(1) assms(3) assms(4) carrier_matD(1) carrier_matD(2)) have "pick I i < m" "pick J j < n" using card_le_pick_inf[OF `infinite I`] card_le_pick_inf[OF `infinite J`] `i < card {i. i < m \ i \ I}`[unfolded set_le_in] `j < card {j. j < n \ j \ J}`[unfolded set_le_in] not_less by metis+ then show ?thesis unfolding 1 by (simp add: assms(2)) qed context deep_model_correct_params_y begin definition witness_submatrix where "witness_submatrix f = submatrix (A' f) rows_with_1 rows_with_1" lemma polyfun_tensor_deep_model: assumes "is \ input_sizes (deep_model_l rs)" shows "polyfun {..f. Tensor.lookup (tensors_from_net (insert_weights shared_weights (deep_model_l rs) f) $ y) is)" proof - have 1:"\f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis then have "y < output_size ( deep_model_l rs)" using valid_deep_model y_valid length_output_deep_model by force have 0:"{..f. A' f $$ (i,j))" proof - have 0:"y < output_size ( deep_model_l rs )" using valid_deep_model y_valid length_output_deep_model by force have 1:"\f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis have 2:"(\f is. is \ replicate (2 * N_half) (last rs) \ polyfun {..x. Tensor.lookup (A x) is))" unfolding A_def using polyfun_tensor_deep_model[unfolded input_sizes_deep_model] 0 by blast show ?thesis unfolding A'_def A_def apply (rule polyfun_matricize) using dims_tensor_deep_model[OF 1] 2[unfolded A_def] using dims_A'_pow[unfolded A'_def A_def] `i<(last rs) ^ N_half` `j<(last rs) ^ N_half` by auto qed lemma polyfun_submatrix_deep_model: assumes "i < r ^ N_half" assumes "j < r ^ N_half" shows "polyfun {..f. witness_submatrix f $$ (i,j))" unfolding witness_submatrix_def proof (rule polyfun_submatrix) have 1:"\f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis show "\f. A' f \ carrier_mat ((last rs) ^ N_half) ((last rs) ^ N_half)" using "1" dims_A'_pow using weight_space_dim_def by auto show "\f i j. i < last rs ^ N_half \ j < last rs ^ N_half \ polyfun {..f. A' f $$ (i, j))" using polyfun_matrix_deep_model weight_space_dim_def by force show "i < card {i. i < last rs ^ N_half \ i \ rows_with_1}" using assms(1) card_rows_with_1 dims_Aw'_pow set_le_in by metis show "j < card {i. i < last rs ^ N_half \ i \ rows_with_1}" using assms(2) card_rows_with_1 dims_Aw'_pow set_le_in by metis show "infinite rows_with_1" "infinite rows_with_1" by (simp_all add: infinite_rows_with_1) qed lemma polyfun_det_deep_model: shows "polyfun {..f. det (witness_submatrix f))" proof (rule polyfun_det) fix f have "remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis show "witness_submatrix f \ carrier_mat (r ^ N_half) (r ^ N_half)" unfolding witness_submatrix_def apply (rule carrier_matI) unfolding dim_submatrix[unfolded set_le_in] unfolding dims_A'_pow[unfolded weight_space_dim_def] using card_rows_with_1 dims_Aw'_pow by simp_all show "\i j. i < r ^ N_half \ j < r ^ N_half \ polyfun {..f. witness_submatrix f $$ (i, j))" using polyfun_submatrix_deep_model by blast qed end end