diff --git a/thys/Deep_Learning/DL_Deep_Model.thy b/thys/Deep_Learning/DL_Deep_Model.thy --- a/thys/Deep_Learning/DL_Deep_Model.thy +++ b/thys/Deep_Learning/DL_Deep_Model.thy @@ -1,951 +1,953 @@ (* Author: Alexander Bentkamp, Universität des Saarlandes *) section \Deep Network Model\ theory DL_Deep_Model imports DL_Network Tensor_Matricization Jordan_Normal_Form.DL_Submatrix DL_Concrete_Matrices DL_Missing_Finite_Set Jordan_Normal_Form.DL_Missing_Sublist Jordan_Normal_Form.Determinant begin hide_const(open) Polynomial.order fun deep_model and deep_model' where "deep_model' Y [] = Input Y" | "deep_model' Y (r # rs) = Pool (deep_model Y r rs) (deep_model Y r rs)" | "deep_model Y r rs = Conv (Y,r) (deep_model' r rs)" abbreviation "deep_model'_l rs == deep_model' (rs!0) (tl rs)" abbreviation "deep_model_l rs == deep_model (rs!0) (rs!1) (tl (tl rs))" lemma valid_deep_model: "valid_net (deep_model Y r rs)" apply (induction rs arbitrary: Y r) apply (simp add: valid_net.intros(1) valid_net.intros(2)) using valid_net.intros(2) valid_net.intros(3) by auto lemma valid_deep_model': "valid_net (deep_model' r rs)" apply (induction rs arbitrary: r) apply (simp add: valid_net.intros(1)) by (metis deep_model'.elims deep_model'.simps(2) deep_model.elims output_size.simps valid_net.simps) lemma input_sizes_deep_model': assumes "length rs \ 1" shows "input_sizes (deep_model'_l rs) = replicate (2^(length rs - 1)) (last rs)" using assms proof (induction "butlast rs" arbitrary:rs) case Nil then have "rs = [rs!0]" by (metis One_nat_def diff_diff_cancel diff_zero length_0_conv length_Suc_conv length_butlast nth_Cons_0) then have "input_sizes (deep_model'_l rs) = [last rs]" by (metis deep_model'.simps(1) input_sizes.simps(1) last.simps list.sel(3)) then show "input_sizes (deep_model'_l rs) = replicate (2 ^ (length rs - 1)) (last rs)" by (metis One_nat_def \[] = butlast rs\ empty_replicate length_butlast list.size(3) power_0 replicate.simps(2)) next case (Cons r rs' rs) then have IH: "input_sizes (deep_model'_l (tl rs)) = replicate (2 ^ (length (tl rs) - 1)) (last rs)" by (metis (no_types, lifting) One_nat_def butlast_tl diff_is_0_eq' last_tl length_Cons length_butlast length_tl list.sel(3) list.size(3) nat_le_linear not_one_le_zero) have "rs = r # (tl rs)" by (metis Cons.hyps(2) Cons.prems One_nat_def append_Cons append_butlast_last_id length_greater_0_conv less_le_trans list.sel(3) zero_less_Suc) then have "deep_model'_l rs = Pool (deep_model_l rs) (deep_model_l rs)" by (metis Cons.hyps(2) One_nat_def butlast.simps(2) deep_model'.elims list.sel(3) list.simps(3) nth_Cons_0 nth_Cons_Suc) then have "input_sizes (deep_model'_l rs) = input_sizes (deep_model_l rs) @ input_sizes (deep_model_l rs)" using input_sizes.simps(3) by metis also have "... = input_sizes (deep_model'_l (tl rs)) @ input_sizes (deep_model'_l (tl rs))" by (metis (no_types, lifting) Cons.hyps(2) One_nat_def deep_model.elims input_sizes.simps(2) length_Cons length_butlast length_greater_0_conv length_tl list.sel(2) list.sel(3) list.size(3) nth_tl one_neq_zero) also have "... = replicate (2 ^ (length (tl rs) - 1)) (last rs) @ replicate (2 ^ (length (tl rs) - 1)) (last rs)" using IH by auto also have "... 
= replicate (2 ^ (length rs - 1)) (last rs)" using replicate_add[of "2 ^ (length (tl rs) - 1)" "2 ^ (length (tl rs) - 1)" "last rs"] by (metis Cons.hyps(2) One_nat_def butlast_tl length_butlast list.sel(3) list.size(4) mult_2_right power_add power_one_right) finally show ?case by auto qed lemma input_sizes_deep_model: assumes "length rs \ 2" shows "input_sizes (deep_model_l rs) = replicate (2^(length rs - 2)) (last rs)" proof - have "input_sizes (deep_model_l rs) = input_sizes (deep_model'_l (tl rs))" by (metis One_nat_def Suc_1 assms hd_Cons_tl deep_model.elims input_sizes.simps(2) length_Cons length_greater_0_conv lessI linorder_not_le list.size(3) not_numeral_le_zero nth_tl) also have "... = replicate (2^(length rs - 2)) (last rs)" using input_sizes_deep_model' by (metis (no_types, lifting) One_nat_def Suc_1 Suc_eq_plus1 assms diff_diff_left hd_Cons_tl last_tl length_Cons length_tl linorder_not_le list.size(3) not_less_eq not_numeral_le_zero numeral_le_one_iff semiring_norm(69)) finally show ?thesis by auto qed lemma evaluate_net_Conv_id: assumes "valid_net' m" and "input_sizes m = map dim_vec input" and "j Tensor.dims (?a$i)" then have "is \ input_sizes m" using `Tensor.dims (?a$i) = input_sizes m` by auto have "valid_net' Convm" by (simp add: assms eye_matrix_dim valid_net.intros(2) Convm_def) have "base_input m is = base_input Convm is" by (simp add: Convm_def base_input_def) have "i < output_size' Convm" unfolding Convm_def remove_weights.simps output_size.simps eye_matrix_dim using assms by metis have "is \ input_sizes (Conv (eye_matrix nr (output_size' m)) m)" by (metis \is \ input_sizes m\ input_sizes.simps(2)) then have f1: "lookup (tensors_from_net (Conv (eye_matrix nr (output_size' m)) m) $ i) is = evaluate_net (Conv (eye_matrix nr (output_size' m)) m) (base_input (Conv (eye_matrix nr (output_size' m)) m) is) $ i" using Convm_def \i < output_size' Convm\ \valid_net' Convm\ lookup_tensors_from_net by blast have "lookup (tensor0 (input_sizes m)) is = (0::real)" by (meson \is \ input_sizes m\ lookup_tensor0) then show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is" using Convm_def \base_input m is = base_input Convm is\ \is \ input_sizes m\ assms(1) assms(2) base_input_length evaluate_net_Conv_id f1 lookup_tensors_from_net by auto qed lemma evaluate_net_Conv_copy_first: assumes "valid_net' m" and "input_sizes m = map dim_vec input" and "j 0" shows "evaluate_net (Conv (copy_first_matrix nr (output_size' m)) m) input $ j = evaluate_net m input $ 0" unfolding evaluate_net.simps output_size_correct[OF assms(1) assms(2)[symmetric]] using mult_copy_first_matrix[OF `joutput_size' m = dim_vec (evaluate_net m input)\ assms(4)) lemma tensors_from_net_Conv_copy_first: assumes "valid_net' m" and "i 0" shows "tensors_from_net (Conv (copy_first_matrix nr (output_size' m)) m) $ i = tensors_from_net m $ 0" (is "?a $ i = ?b") proof (rule tensor_lookup_eqI) have "Tensor.dims (?a$i) = input_sizes m" by (metis assms(1) assms(2) copy_first_matrix_dim(1) copy_first_matrix_dim(2) dims_tensors_from_net input_sizes.simps(2) output_size.simps(2) output_size_correct_tensors remove_weights.simps(2) valid_net.intros(2) vec_setI) moreover have "Tensor.dims (?b) = input_sizes m" using dims_tensors_from_net output_size_correct_tensors[OF assms(1)] using assms(3) by (simp add: vec_setI) ultimately show "Tensor.dims (?a$i) = Tensor.dims (?b)" by auto define Convm where "Convm = Conv (copy_first_matrix nr (output_size' m)) m" fix "is" assume "is \ Tensor.dims (?a$i)" then have "is \ input_sizes m" using 
`Tensor.dims (?a$i) = input_sizes m` by auto have "valid_net' Convm" by (simp add: assms copy_first_matrix_dim valid_net.intros(2) Convm_def) have "base_input m is = base_input Convm is" by (simp add: Convm_def base_input_def) have "i < output_size' Convm" unfolding Convm_def remove_weights.simps output_size.simps copy_first_matrix_dim using assms by metis show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is" by (metis Convm_def \base_input m is = base_input Convm is\ \i < output_size' Convm\ \is \ input_sizes m\ \valid_net' Convm\ assms(1) assms(2) assms(3) base_input_length evaluate_net_Conv_copy_first input_sizes.simps(2) lookup_tensors_from_net) qed lemma evaluate_net_Conv_all1: assumes "valid_net' m" and "input_sizes m = map dim_vec input" and "i Tensor.dims (?a $ i)" then have "is \ input_sizes m" using \i < dim_vec ?a\ dims_tensors_from_net input_sizes.simps(2) by (metis vec_setI) then have "is \ input_sizes Convm" by (simp add: Convm_def) have "valid_net' Convm" by (simp add: Convm_def assms all1_matrix_dim valid_net.intros(2)) have "i< output_size' Convm" using Convm_def \i < dim_vec ?a\ \valid_net' Convm\ output_size_correct_tensors by presburger have "base_input Convm is = base_input m is" unfolding base_input_def Convm_def input_sizes.simps by metis have "Tensor.lookup (?a $ i) is = evaluate_net Convm (base_input Convm is) $ i" using lookup_tensors_from_net[OF `valid_net' Convm` `is \ input_sizes Convm` `i< output_size' Convm`] by (metis Convm_def ) also have "... = monoid_add_class.sum_list (list_of_vec (evaluate_net m (base_input Convm is)))" using evaluate_net_Conv_all1 Convm_def \is \ input_sizes Convm\ assms base_input_length \i < nr\ by simp also have "... = monoid_add_class.sum_list (list_of_vec (map_vec (\A. lookup A is)(tensors_from_net m)))" unfolding `base_input Convm is = base_input m is` using lookup_tensors_from_net[OF `valid_net' m` `is \ input_sizes m`] base_input_length[OF \is \ input_sizes m\] output_size_correct[OF assms(1)] output_size_correct_tensors[OF assms(1)] eq_vecI[of "evaluate_net m (base_input m is)" "map_vec (\A. lookup A is) (tensors_from_net m)"] index_map_vec(1) index_map_vec(2) by force also have "... = monoid_add_class.sum_list (map (\A. lookup A is) (list_of_vec (tensors_from_net m)))" using eq_vecI[of "vec_of_list (list_of_vec (map_vec (\A. lookup A is)(tensors_from_net m)))" "vec_of_list (map (\A. lookup A is) (list_of_vec (tensors_from_net m)))"] dim_vec_of_list nth_list_of_vec length_map list_vec nth_map index_map_vec(1) index_map_vec(2) vec_list by (metis (no_types, lifting)) also have "... 
= Tensor.lookup ?b is" using dims_tensors_from_net set_list_of_vec using lookup_listsum[OF `is \ input_sizes m`, of "list_of_vec (tensors_from_net m)"] by metis finally show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is" by blast qed fun witness and witness' where "witness' Y [] = Input Y" | "witness' Y (r # rs) = Pool (witness Y r rs) (witness Y r rs)" | "witness Y r rs = Conv ((if length rs = 0 then eye_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) (witness' r rs)" abbreviation "witness_l rs == witness (rs!0) (rs!1) (tl (tl rs))" abbreviation "witness'_l rs == witness' (rs!0) (tl rs)" lemma witness_is_deep_model: "remove_weights (witness Y r rs) = deep_model Y r rs" proof (induction rs arbitrary: Y r) case Nil then show ?case unfolding witness.simps witness'.simps deep_model.simps deep_model'.simps by (simp add: eye_matrix_dim) next case (Cons r' rs Y r) have "dim_row ((if length (r' # rs) = 0 then eye_matrix else (if length (r' # rs) = 1 then all1_matrix else copy_first_matrix)) Y r) = Y" "dim_col ((if length (r' # rs) = 0 then eye_matrix else (if length (r' # rs) = 1 then all1_matrix else copy_first_matrix)) Y r) = r" by (simp_all add: all1_matrix_dim copy_first_matrix_dim) then show ?case unfolding witness.simps unfolding witness'.simps unfolding remove_weights.simps using Cons by simp qed lemma witness'_is_deep_model: "remove_weights (witness' Y rs) = deep_model' Y rs" proof (induction rs arbitrary: Y) case Nil then show ?case unfolding witness.simps witness'.simps deep_model.simps deep_model'.simps by (simp add: eye_matrix_dim) next case (Cons r rs Y) have "dim_row ((if length rs = 0 then eye_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) = Y" "dim_col ((if length rs = 0 then eye_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) = r" by (simp_all add: all1_matrix_dim copy_first_matrix_dim eye_matrix_dim) then show ?case unfolding witness'.simps unfolding witness.simps unfolding remove_weights.simps using Cons by simp qed lemma witness_valid: "valid_net' (witness Y r rs)" using valid_deep_model witness_is_deep_model by auto lemma witness'_valid: "valid_net' (witness' Y rs)" using valid_deep_model' witness'_is_deep_model by auto lemma shared_weight_net_witness: "shared_weight_net (witness Y r rs)" proof (induction rs arbitrary:Y r) case Nil then show ?case unfolding witness.simps witness'.simps by (simp add: shared_weight_net_Conv shared_weight_net_Input) next case (Cons a rs) then show ?case unfolding witness.simps witness'.simps by (simp add: shared_weight_net_Conv shared_weight_net_Input shared_weight_net_Pool) qed lemma witness_l0': "witness' Y [M] = (Pool (Conv (eye_matrix Y M) (Input M)) (Conv (eye_matrix Y M) (Input M)) )" unfolding witness'.simps witness.simps by simp lemma witness_l1: "witness Y r0 [M] = Conv (all1_matrix Y r0) (witness' r0 [M])" unfolding witness'.simps by simp lemma tensors_ht_l0: assumes "j unit_vec M j = tensor_from_lookup [M,M] (\is. 
if is=[j,j] then 1 else 0)" (is "?A=?B") proof (rule tensor_lookup_eqI) show "Tensor.dims ?A = Tensor.dims ?B" by (metis append_Cons self_append_conv2 dims_unit_vec dims_tensor_prod dims_tensor_from_lookup) fix "is" assume is_valid:"is \ Tensor.dims (unit_vec M j \ unit_vec M j)" then have "is \ [M,M]" by (metis append_Cons self_append_conv2 dims_unit_vec dims_tensor_prod) then obtain i1 i2 where is_split: "is = [i1, i2]" "i1 < M" "i2 < M" using list.distinct(1) by blast then have "[i1] \ Tensor.dims (unit_vec M j)" "[i2] \ Tensor.dims (unit_vec M j)" by (simp_all add: valid_index.Cons valid_index.Nil dims_unit_vec) have "is = [i1] @ [i2]" by (simp add: is_split(1)) show "Tensor.lookup ?A is = Tensor.lookup ?B is" unfolding `is = [i1] @ [i2]` lookup_tensor_prod[OF `[i1] \ Tensor.dims (unit_vec M j)` `[i2] \ Tensor.dims (unit_vec M j)`] lookup_tensor_from_lookup[OF \is \ [M, M]\, unfolded `is = [i1] @ [i2]`] lookup_unit_vec[OF \i1 < M\] lookup_unit_vec[OF \i2 < M\] by fastforce qed lemma tensors_ht_l0': assumes "j unit_vec M j else tensor0 [M,M])" (is "_ = ?b") proof - have "valid_net' (Conv (eye_matrix r0 M) (Input M))" by (metis convnet.inject(3) list.discI witness'.elims witness_l0' witness_valid) have j_le:"j < dim_vec (tensors_from_net (Conv (eye_matrix r0 M) (Input M)))" using output_size_correct_tensors[OF `valid_net' (Conv (eye_matrix r0 M) (Input M))`, unfolded remove_weights.simps output_size.simps eye_matrix_dim] assms by simp show ?thesis unfolding tensors_from_net.simps(3) witness_l0' index_component_mult[OF j_le j_le] tensors_ht_l0[OF assms] by auto qed lemma lookup_tensors_ht_l0': assumes "j [M,M]" shows "(Tensor.lookup (tensors_from_net (witness' r0 [M]) $ j)) is = (if is=[j,j] then 1 else 0)" proof (cases "jj [j, j]" using assms(2) using list.distinct(1) nth_Cons_0 valid_index.simps by blast show ?thesis unfolding tensors_ht_l0'[OF assms(1)] tensor_prod_unit_vec using `\j [j, j]`) qed lemma lookup_tensors_ht_l1: assumes "j < r1" and "is \ [M,M]" shows "Tensor.lookup (tensors_from_net (witness r1 r0 [M]) $ j) is = (if is!0 = is!1 \ is!0output_size' (witness' r0 [M]) = r0\ witness_l0'_valid output_size_correct_tensors by fastforce have all0_but1:"\i. i\is!0 \ i Tensor.lookup (tensors_from_net (witness' r0 [M]) $ i) is = 0" using lookup_tensors_ht_l0' \is \ [M, M]\ by auto have "tensors_from_net (witness r1 r0 [M]) $ j = Tensor_Plus.listsum [M,M] (list_of_vec (tensors_from_net (witness' r0 [M])))" unfolding witness_l1 using tensors_from_net_Conv_all1[OF witness_l0'_valid assms(1)] witness_l0' `output_size' (witness' r0 [M]) = r0` by simp then have "Tensor.lookup (tensors_from_net (witness r1 r0 [M]) $ j) is = monoid_add_class.sum_list (map (\A. Tensor.lookup A is) (list_of_vec (tensors_from_net (witness' r0 [M]))))" using lookup_listsum[OF `is \ [M, M]`] \input_sizes (witness' r0 [M]) = [M, M]\ dims_tensors_from_net by (metis set_list_of_vec) also have "... = monoid_add_class.sum_list (map (\i. lookup (tensors_from_net (witness' r0 [M]) $ i) is) [0..A. Tensor.lookup A is)" "\i. (tensors_from_net (witness' r0 [M]) $ i)" "[0..i is!0 {0..i {0..i. (Tensor.lookup (tensors_from_net (witness' r0 [M])$i) is)"] using all0_but1 atLeast0LessThan by force then show ?thesis using lookup_tensors_ht_l0' \is ! 0 < r0\ \is \ [M, M]\ by fastforce next case False then show ?thesis using all0_but1 atLeast0LessThan sum.neutral by force qed finally show ?thesis by auto qed lemma length_output_deep_model: assumes "remove_weights m = deep_model_l rs" shows "dim_vec (tensors_from_net m) = rs ! 
0" using output_size_correct_tensors valid_deep_model deep_model.elims output_size.simps(2) by (metis assms) lemma length_output_deep_model': assumes "remove_weights m = deep_model'_l rs" shows "dim_vec (tensors_from_net m) = rs ! 0" using output_size_correct_tensors valid_deep_model' deep_model'.elims output_size.simps by (metis assms deep_model.elims) lemma length_output_witness: "dim_vec (tensors_from_net (witness_l rs)) = rs ! 0" using length_output_deep_model witness_is_deep_model by blast lemma length_output_witness': "dim_vec (tensors_from_net (witness'_l rs)) = rs ! 0" using length_output_deep_model' witness'_is_deep_model by blast lemma dims_output_deep_model: assumes "length rs \ 2" and "\r. r\set rs \ r > 0" and "j < rs!0" and "remove_weights m = deep_model_l rs" shows "Tensor.dims (tensors_from_net m $ j) = replicate (2^(length rs - 2)) (last rs)" using dims_tensors_from_net input_sizes_deep_model[OF assms(1)] output_size_correct_tensors valid_deep_model assms(3) assms(4) input_sizes_remove_weights length_output_witness witness_is_deep_model by (metis vec_setI) lemma dims_output_witness: assumes "length rs \ 2" and "\r. r\set rs \ r > 0" and "j < rs!0" shows "Tensor.dims (tensors_from_net (witness_l rs) $ j) = replicate (2^(length rs - 2)) (last rs)" using dims_output_deep_model witness_is_deep_model assms by blast lemma dims_output_deep_model': assumes "length rs \ 1" and "\r. r\set rs \ r > 0" and "j < rs!0" and "remove_weights m = deep_model'_l rs" shows "Tensor.dims (tensors_from_net m $ j) = replicate (2^(length rs - 1)) (last rs)" proof - have "dim_vec (tensors_from_net m) > j" using length_output_deep_model' `remove_weights m = deep_model'_l rs` `j < rs!0` by auto then have "Tensor.dims (tensors_from_net m $ j) = input_sizes m" using dims_tensors_from_net[of _ m] output_size_correct_tensors vec_setI by metis then show ?thesis using assms(1) input_sizes_deep_model' input_sizes_remove_weights[of m, unfolded `remove_weights m = deep_model'_l rs`] by auto qed lemma dims_output_witness': assumes "length rs \ 1" and "\r. r\set rs \ r > 0" and "j < rs!0" shows "Tensor.dims (tensors_from_net (witness'_l rs) $ j) = replicate (2^(length rs - 1)) (last rs)" using dims_output_deep_model' assms witness'_is_deep_model by blast abbreviation "ten2mat == matricize {n. even n}" abbreviation "mat2ten == dematricize {n. even n}" locale deep_model_correct_params = fixes shared_weights::bool fixes rs::"nat list" assumes deep:"length rs \ 3" and no_zeros:"\r. r\set rs \ 0 < r" begin definition "r = min (last rs) (last (butlast rs))" definition "N_half = 2^(length rs - 3)" definition "weight_space_dim = count_weights shared_weights (deep_model_l rs)" end locale deep_model_correct_params_y = deep_model_correct_params + fixes y::nat assumes y_valid:"y < rs ! 
0" begin -definition "A ws = tensors_from_net (insert_weights shared_weights (deep_model_l rs) ws) $ y" -definition "A' ws = ten2mat (A ws)" +definition A :: "(nat \ real) \ real tensor" + where "A ws = tensors_from_net (insert_weights shared_weights (deep_model_l rs) ws) $ y" +definition A' :: "(nat \ real) \ real mat" + where "A' ws = ten2mat (A ws)" lemma dims_tensor_deep_model: assumes "remove_weights m = deep_model_l rs" shows "dims (tensors_from_net m $ y) = replicate (2 * N_half) (last rs)" proof - have "dims (tensors_from_net m $ y) = replicate (2 ^ (length rs - 2)) (last rs)" using dims_output_deep_model[OF _ no_zeros y_valid assms] using less_imp_le_nat Suc_le_lessD deep numeral_3_eq_3 by auto then show ?thesis using N_half_def by (metis One_nat_def Suc_1 Suc_eq_plus1 Suc_le_lessD deep diff_diff_left less_numeral_extra(3) numeral_3_eq_3 realpow_num_eq_if zero_less_diff) qed lemma order_tensor_deep_model: assumes "remove_weights m = deep_model_l rs" shows "order (tensors_from_net m $ y) = 2 * N_half" using dims_tensor_deep_model by (simp add: assms) lemma dims_A: shows "Tensor.dims (A ws) = replicate (2 * N_half) (last rs)" unfolding A_def using dims_tensor_deep_model remove_insert_weights by blast lemma order_A: shows "order (A ws) = 2 * N_half" using dims_A length_replicate by auto lemma dims_A': shows "dim_row (A' ws) = prod_list (nths (Tensor.dims (A ws)) {n. even n})" and "dim_col (A' ws) = prod_list (nths (Tensor.dims (A ws)) {n. odd n})" unfolding A'_def matricize_def by (simp_all add: A_def Collect_neg_eq) lemma dims_A'_pow: shows "dim_row (A' ws) = (last rs) ^ N_half" "dim_col (A' ws) = (last rs) ^ N_half" unfolding dims_A' dims_A nths_replicate set_le_in card_even card_odd prod_list_replicate by simp_all definition "Aw = tensors_from_net (witness_l rs) $ y" definition "Aw' = ten2mat Aw" definition "witness_weights = extract_weights shared_weights (witness_l rs)" lemma witness_weights:"witness_l rs = insert_weights shared_weights (deep_model_l rs) witness_weights" by (metis (full_types) insert_extract_weights_cong_shared insert_extract_weights_cong_unshared shared_weight_net_witness witness_is_deep_model witness_weights_def) lemma Aw_def': "Aw = A witness_weights" unfolding Aw_def A_def using witness_weights by auto lemma Aw'_def': "Aw' = A' witness_weights" unfolding Aw'_def A'_def Aw_def' by auto lemma dims_Aw: "Tensor.dims Aw = replicate (2 * N_half) (last rs)" unfolding Aw_def' using dims_A by auto lemma order_Aw: "order Aw = 2 * N_half" unfolding Aw_def' using order_A by auto lemma dims_Aw': "dim_row Aw' = prod_list (nths (Tensor.dims Aw) {n. even n})" "dim_col Aw' = prod_list (nths (Tensor.dims Aw) {n. odd n})" unfolding Aw'_def' Aw_def' using dims_A' by auto lemma dims_Aw'_pow: "dim_row Aw' = (last rs) ^ N_half" "dim_col Aw' = (last rs) ^ N_half" unfolding Aw'_def' Aw_def' using dims_A'_pow by auto lemma witness_tensor: assumes "is \ Tensor.dims Aw" shows "Tensor.lookup Aw is = (if nths is {n. even n} = nths is {n. odd n} \ (\i\set is. 
i < last (butlast rs)) then 1 else 0)" using assms deep no_zeros y_valid unfolding Aw_def proof (induction "butlast (butlast (butlast rs))" arbitrary:rs "is" y) case Nil have "length rs = 3" by (rule antisym, metis Nil.hyps One_nat_def Suc_1 Suc_eq_plus1 add_2_eq_Suc' diff_diff_left length_butlast less_numeral_extra(3) list.size(3) not_le numeral_3_eq_3 zero_less_diff, metis `3 \ length rs`) then have "rs = [rs!0, rs!1, rs!2]" by (metis (no_types, lifting) Cons_nth_drop_Suc One_nat_def Suc_eq_plus1 append_Nil id_take_nth_drop length_0_conv length_tl lessI list.sel(3) list.size(4) not_le numeral_3_eq_3 numeral_le_one_iff one_add_one semiring_norm(70) take_0 zero_less_Suc) have "input_sizes (witness_l [rs ! 0, rs ! 1, rs ! 2]) = [rs!2, rs!2]" using witness.simps witness'.simps input_sizes.simps by auto then have "Tensor.dims (tensors_from_net (witness_l rs) $ y) = [rs!2, rs!2]" using dims_tensors_from_net[of "tensors_from_net (witness_l rs) $ y" "witness_l rs"] Nil.prems(4) length_output_witness \rs = [rs ! 0, rs ! 1, rs ! 2]\ vec_setI by metis then have "is \ [rs!2, rs!2]" using Nil.prems by metis then have "Tensor.lookup ((tensors_from_net (witness_l rs))$y) is = (if is ! 0 = is ! 1 \ is ! 0 < rs ! 1 then 1 else 0)" using Nil.prems(4) \rs = [rs ! 0, rs ! 1, rs ! 2]\ by (metis list.sel(3) lookup_tensors_ht_l1) have "is ! 0 = is ! 1 \ is ! 0 < rs ! 1 \ nths is {n. even n} = nths is {n. odd n} \ (\i\set is. i < last (butlast rs))" proof - have "length is = 2" by (metis One_nat_def Suc_eq_plus1 \is \ [rs ! 2, rs ! 2]\ list.size(3) list.size(4) numeral_2_eq_2 valid_index_length) have "nths is {n. even n} = [is!0]" apply (rule nths_only_one) using subset_antisym less_2_cases `length is = 2` by fastforce have "nths is {n. odd n} = [is!1]" apply (rule nths_only_one) using subset_antisym less_2_cases `length is = 2` by fastforce have "last (butlast rs) = rs!1" by (metis One_nat_def Suc_eq_plus1 \rs = [rs ! 0, rs ! 1, rs ! 2]\ append_butlast_last_id last_conv_nth length_butlast length_tl lessI list.sel(3) list.simps(3) list.size(3) list.size(4) nat.simps(3) nth_append) show ?thesis unfolding `last (butlast rs) = rs!1` apply (rule iffI; rule conjI) apply (simp add: \nths is (Collect even) = [is ! 0]\ \nths is {n. odd n} = [is ! 1]\) apply (metis `length is = 2` One_nat_def in_set_conv_nth less_2_cases) apply (simp add: \nths is (Collect even) = [is ! 0]\ \nths is {n. odd n} = [is ! 1]\) apply (simp add: \length is = 2\) done qed then show ?case unfolding \Tensor.lookup (tensors_from_net (witness_l rs) $ y) is = (if is ! 0 = is ! 1 \ is ! 0 < rs ! 1 then 1 else 0)\ using witness_is_deep_model witness_valid \rs = [rs ! 0, rs ! 1, rs ! 2]\ by auto next case (Cons r rs' rs "is" j) text \We prove the Induction Hypothesis for "tl rs" and j=0:\ have "rs = r # tl rs" by (metis Cons.hyps(2) append_butlast_last_id butlast.simps(1) hd_append2 list.collapse list.discI list.sel(1)) have 1:"rs' = butlast (butlast (butlast (tl rs)))" by (metis Cons.hyps(2) butlast_tl list.sel(3)) have 2:"3 \ length (tl rs)" by (metis (no_types, lifting) Cons.hyps(2) Cons.prems(2) Nitpick.size_list_simp(2) One_nat_def Suc_eq_plus1 \rs = r # tl rs\ \rs' = butlast (butlast (butlast (tl rs)))\ diff_diff_left diff_self_eq_0 gr0_conv_Suc le_Suc_eq length_butlast length_tl less_numeral_extra(3) list.simps(3) numeral_3_eq_3) have 3:"\r. r \ set (tl rs) \ 0 < r" by (metis Cons.prems(3) list.sel(2) list.set_sel(2)) have 4:"0 < (tl rs) ! 0" using "2" "3" by auto have IH: "\is'. 
is' \ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0) \ Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is' = (if nths is' (Collect even) = nths is' {n. odd n} \ (\i\set is'. i < last (butlast (tl rs))) then 1 else 0)" using "1" "2" "3" 4 Cons.hyps(1) by blast text \The list "is" can be split in two parts:\ have "is \ replicate (2^(length rs - 2)) (last rs)" using Cons.prems(3) dims_output_witness 2 by (metis (no_types, lifting) Cons.prems(1) Cons.prems(3) Cons.prems(4) Nitpick.size_list_simp(2) One_nat_def diff_diff_left diff_is_0_eq length_tl nat_le_linear not_numeral_le_zero numeral_le_one_iff one_add_one semiring_norm(70)) then have "is \ replicate (2^(length (tl rs) - 2)) (last rs) @ replicate (2^(length (tl rs) - 2)) (last rs)" using Cons.prems dims_output_witness by (metis "2" Nitpick.size_list_simp(2) One_nat_def diff_diff_left length_tl mult_2 not_numeral_le_zero numeral_le_one_iff one_add_one power.simps(2) replicate_add semiring_norm(70)) then obtain is1 is2 where "is = is1 @ is2" and is1_replicate: "is1 \ replicate (2^(length (tl rs) - 2)) (last rs)" and is2_replicate: "is2 \ replicate (2^(length (tl rs) - 2)) (last rs)" by (metis valid_index_split) then have is1_valid: "is1 \ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0)" (is ?is1) and is2_valid: "is2 \ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0)" (is ?is2) proof - have "last (tl rs) = last rs" by (metis "2" \rs = r # tl rs\ last_ConsR list.size(3) not_numeral_le_zero) then show ?is1 ?is2 using dims_output_witness[of "tl rs"] using dims_output_witness[of "tl rs"] 2 3 is1_replicate is2_replicate \last (tl rs) = last rs\ by auto qed text \A shorthand for the condition to find a "1" in the tensor:\ let ?cond = "\is rs. nths is {n. even n} = nths is {n. odd n} \ (\i\set is. i < last (butlast rs))" text \We can use the IH on our newly created is1 and is2:\ have IH_is12: "Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is1 = (if (?cond is1 (tl rs)) then 1 else 0)" "Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is2 = (if (?cond is2 (tl rs)) then 1 else 0)" using IH is1_valid is2_valid by fast+ text \In the induction step we have to add two layers: first the Pool layer, then the Conv layer. The Pool layer connects the two subtrees. Therefore the two conditions on is1 and is2 become one, and we have to prove that they are equivalent:\ have "?cond is1 (tl rs) \ ?cond is2 (tl rs) \ ?cond is rs" proof - have "length is1 = 2 ^ (length (tl rs) - 2)" "length is2 = 2 ^ (length (tl rs) - 2)" using is1_replicate is2_replicate by (simp_all add: valid_index_length) then have "even (length is1)" "even (length is2)" by (metis Cons.hyps(2) One_nat_def add_gr_0 diff_diff_left even_numeral even_power length_butlast length_tl list.size(4) one_add_one zero_less_Suc)+ then have "{j. j + length is1 \ {n. even n}} = {n. even n}" "{j. j + length is1 \ {n. odd n}} = {n. odd n}" by simp_all have "length (nths is2 (Collect even)) = length (nths is2 (Collect odd))" using length_nths_even \even (length is2)\ by blast have cond1_iff: "(nths is1 (Collect even) = nths is1 {n. odd n} \ nths is2 (Collect even) = nths is2 {n. odd n}) = (nths is (Collect even) = nths is {n. odd n})" unfolding `is = is1 @ is2` nths_append `{j. j + length is1 \ {n. odd n}} = {n. odd n}` `{j. j + length is1 \ {n. even n}} = {n. 
even n}` by (simp add: \length (nths is2 (Collect even)) = length (nths is2 (Collect odd))\) have "last (butlast (tl rs)) = last (butlast rs)" using Nitpick.size_list_simp(2) \even (length is1)\ \length is1 = 2 ^ (length (tl rs) - 2)\ butlast_tl last_tl length_butlast length_tl not_less_eq zero_less_diff by (metis (full_types) Cons.hyps(2) length_Cons less_nat_zero_code) have cond2_iff: "(\i\set is1. i < last (butlast (tl rs))) \ (\i\set is2. i < last (butlast (tl rs))) \ (\i\set is. i < last (butlast rs))" unfolding `last (butlast (tl rs)) = last (butlast rs)` `is = is1 @ is2` set_append by blast then show ?thesis using cond1_iff cond2_iff by blast qed text \Now we can make the Pool layer step: \ have lookup_witness': "Tensor.lookup ((tensors_from_net (witness' (rs ! 1) (tl (tl rs)))) $ 0) is = (if ?cond is rs then 1 else 0)" proof - have lookup_prod: "Tensor.lookup ((tensors_from_net (witness_l (tl rs)) $ 0) \ (tensors_from_net (witness_l (tl rs))) $ 0) is = (if ?cond is rs then 1 else 0)" using `?cond is1 (tl rs) \ ?cond is2 (tl rs) \ ?cond is rs` unfolding `is = is1 @ is2` lookup_tensor_prod[OF is1_valid is2_valid] IH_is12 by auto have witness_l_tl: "witness_l (tl rs) = witness (rs ! 1) (rs ! 2) (tl (tl (tl rs)))" by (metis One_nat_def Suc_1 \rs = r # tl rs\ nth_Cons_Suc) have tl_tl:"(tl (tl rs)) = ((rs ! 2) # tl (tl (tl rs)))" proof - have "length (tl (tl rs)) \ 0" by (metis One_nat_def Suc_eq_plus1 diff_diff_left diff_is_0_eq length_tl not_less_eq_eq Cons.prems(2) numeral_3_eq_3) then have "tl (tl rs) \ []" by fastforce then show ?thesis by (metis list.exhaust_sel nth_Cons_0 nth_Cons_Suc numeral_2_eq_2 tl_Nil) qed have length_gt0:"dim_vec (tensors_from_net (witness (rs ! 1) (rs ! 2) (tl (tl (tl rs))))) > 0" using output_size_correct_tensors[of "witness (rs ! 1) (rs ! 2) (tl (tl (tl rs)))"] witness_is_deep_model[of "rs ! 1" "rs ! 2" "tl (tl (tl rs))"] valid_deep_model[of "rs ! 1" "rs ! 2" "tl (tl (tl rs))"] output_size.simps witness.simps by (metis "2" "3" One_nat_def \rs = r # tl rs\ deep_model.elims length_greater_0_conv list.size(3) not_numeral_le_zero nth_Cons_Suc nth_mem) then have "tensors_from_net (witness' (rs ! 1) ((rs ! 2) # tl (tl (tl rs)))) $ 0 = (tensors_from_net (witness_l (tl rs)) $ 0) \ (tensors_from_net (witness_l (tl rs)) $ 0)" unfolding witness'.simps tensors_from_net.simps witness_l_tl using index_component_mult by blast then show ?thesis using lookup_prod tl_tl by simp qed text \Then we can make the Conv layer step: \ show ?case proof - have "valid_net' (witness' (rs ! 1) (tl (tl rs)))" by (simp add: witness'_valid) have "output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 1" by (metis "2" Nitpick.size_list_simp(2) diff_diff_left diff_is_0_eq hd_Cons_tl deep_model'.simps(2) deep_model.elims length_tl not_less_eq_eq numeral_2_eq_2 numeral_3_eq_3 one_add_one output_size.simps(2) output_size.simps(3) tl_Nil witness'_is_deep_model) have if_resolve:"(if length (tl (tl rs)) = 0 then eye_matrix else if length (tl (tl rs)) = 1 then all1_matrix else copy_first_matrix) = copy_first_matrix" by (metis "2" Cons.prems(2) Nitpick.size_list_simp(2) One_nat_def Suc_n_not_le_n not_numeral_le_zero numeral_3_eq_3) have "tensors_from_net (Conv (copy_first_matrix (rs ! 0) (rs ! 1)) (witness' (rs ! 1) (tl (tl rs)))) $ j = tensors_from_net (witness' (rs ! 1) (tl (tl rs))) $ 0" using tensors_from_net_Conv_copy_first[OF `valid_net' (witness' (rs ! 1) (tl (tl rs)))` `j < rs ! 0`, unfolded `output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 
1`] using "4" One_nat_def \rs = r # tl rs\ nth_Cons_Suc by metis then show ?thesis unfolding witness.simps if_resolve `output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 1` using lookup_witness' \valid_net' (witness' (rs ! 1) (tl (tl rs)))\ hd_conv_nth output_size_correct_tensors by fastforce qed qed lemma witness_matricization: assumes "i < dim_row Aw'" and "j < dim_col Aw'" shows "Aw' $$ (i, j) = (if i=j \ (\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < last (butlast rs)) then 1 else 0)" proof - define "is" where "is = weave {n. even n} (digit_encode (nths (Tensor.dims Aw) {n. even n}) i) (digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)" have lookup_eq: "Aw' $$ (i, j) = Tensor.lookup Aw is" using Aw'_def matricize_def dims_Aw'(1)[symmetric, unfolded A_def] dims_Aw'(2)[symmetric, unfolded A_def Collect_neg_eq] index_mat(1)[OF `i < dim_row Aw'` `j < dim_col Aw'`] is_def Collect_neg_eq case_prod_conv by (metis (no_types) Aw'_def Collect_neg_eq case_prod_conv is_def matricize_def) have "is \ Tensor.dims Aw" using is_def valid_index_weave A_def Collect_neg_eq assms digit_encode_valid_index dims_Aw' by metis have "even (order Aw)" unfolding Aw_def using assms dims_output_witness even_numeral le_eq_less_or_eq numeral_2_eq_2 numeral_3_eq_3 deep no_zeros y_valid by fastforce have nths_dimsAw: "nths (Tensor.dims Aw) (Collect even) = nths (Tensor.dims Aw) {n. odd n}" proof - have 0:"Tensor.dims (tensors_from_net (witness_l rs) $ y) = replicate (2 ^ (length rs - 2)) (last rs)" using dims_output_witness[OF _ no_zeros y_valid] using deep by linarith show ?thesis unfolding A_def using nths_replicate by (metis (no_types, lifting) "0" Aw_def \even (order Aw)\ length_replicate length_nths_even) qed have "i = j \ nths is (Collect even) = nths is {n. odd n}" proof have eq_lengths: "length (digit_encode (nths (Tensor.dims Aw) (Collect even)) i) = length (digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)" unfolding length_digit_encode by (metis \even (order Aw)\ length_nths_even) then show "i = j \ nths is (Collect even) = nths is {n. odd n}" unfolding is_def using nths_weave[of "digit_encode (nths (Tensor.dims Aw) (Collect even)) i" "Collect even" "digit_encode (nths (Tensor.dims Aw) {n. odd n}) j", unfolded eq_lengths, unfolded Collect_neg_eq[symmetric] card_even mult_2[symmetric] card_odd] nths_dimsAw by simp show "nths is (Collect even) = nths is {n. odd n} \ i = j" unfolding is_def using nths_weave[of "digit_encode (nths (Tensor.dims Aw) (Collect even)) i" "Collect even" "digit_encode (nths (Tensor.dims Aw) {n. odd n}) j", unfolded eq_lengths, unfolded Collect_neg_eq[symmetric] card_even mult_2[symmetric] card_odd] using \nths (Tensor.dims Aw) (Collect even) = nths (Tensor.dims Aw) {n. odd n}\ deep no_zeros y_valid assms digit_decode_encode dims_Aw' by auto (metis digit_decode_encode_lt) qed have "i=j \ set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i) = set is" unfolding is_def nths_dimsAw using set_weave[of "(digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)" "Collect even" "(digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)", unfolded mult_2[symmetric] card_even Collect_neg_eq[symmetric] card_odd] Un_absorb card_even card_odd mult_2 by blast then show ?thesis unfolding lookup_eq using witness_tensor[OF `is \ Tensor.dims Aw`] by (simp add: A_def \(i = j) = (nths is (Collect even) = nths is {n. odd n})\) qed definition "rows_with_1 = {i. (\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < last (butlast rs))}" lemma card_low_digits: assumes "m>0" "\d. 
d\set ds \ m \ d" shows "card {i. i (\i0\set (digit_encode ds i). i0 < m)} = m ^ (length ds)" using assms proof (induction ds) case Nil then show ?case using prod_list.Nil by simp next case (Cons d ds) define low_digits where "low_digits ds i \ i < prod_list ds \ (\i0\set (digit_encode ds i). i0 < m)" for ds i have "card {i. low_digits ds i} = m ^ (length ds)" unfolding low_digits_def by (simp add: Cons.IH Cons.prems(1) Cons.prems(2)) have "card {i. low_digits (d # ds) i} = card ({.. {i. low_digits ds i})" proof - define f where "f p = fst p + d * snd p" for p have "inj_on f ({.. {i. low_digits ds i})" proof (rule inj_onI) fix x y assume "x \ {.. {i. low_digits ds i}" "y \ {.. {i. low_digits ds i}" "f x = f y" then have "fst xf x = f y\ \f x mod d = fst x\ \fst y < d\ f_def by auto show "x = y" using \f x = f y\ \f x div d = snd x\ \f x mod d = fst x\ \f y div d = snd y\ \f y mod d = fst y\ prod_eqI by fastforce qed have "f ` ({.. {i. low_digits ds i}) = {i. low_digits (d # ds) i}" proof (rule subset_antisym; rule subsetI) fix x assume "x \ f ` ({.. {i. low_digits ds i})" then obtain i0 i1 where "x = i0 + d * i1" "i0 < m" "low_digits ds i1" using f_def by force then have "i0 {i. low_digits (d # ds) i}" unfolding low_digits_def proof (rule; rule conjI) have "i1 < prod_list ds" "\i0\set (digit_encode ds i1). i0 < m" using `low_digits ds i1` low_digits_def by auto show "x < prod_list (d # ds)" unfolding prod_list.Cons `x = i0 + d * i1` using `i0 0" by (metis \i0 < d\ gr_implies_not0) then have "(i0 + d * i1) div (d * prod_list ds) = 0" by (simp add: Divides.div_mult2_eq \i0 < d\ \i1 < prod_list ds\) then show "i0 + d * i1 < d * prod_list ds" by (metis (no_types) \i0 < d\ \i1 < prod_list ds\ div_eq_0_iff gr_implies_not0 no_zero_divisors) qed show "\i0\set (digit_encode (d # ds) x). i0 < m" using \\i0\set (digit_encode ds i1). i0 < m\ \i0 < d\ \i0 < m\ \x = i0 + d * i1\ by auto qed next fix x assume "x \ {i. low_digits (d # ds) i}" then have "x < prod_list (d # ds)" "\i0\set (digit_encode (d # ds) x). i0 < m" using low_digits_def by auto have "x mod d < m" using `\i0\set (digit_encode (d # ds) x). i0 < m`[unfolded digit_encode.simps] by simp have "x div d < prod_list ds" using `x < prod_list (d # ds)`[unfolded prod_list.Cons] by (metis div_eq_0_iff div_mult2_eq mult_0_right not_less0) have "\i0\set (digit_encode ds (x div d)). i0 < m" by (simp add: \\i0\set (digit_encode (d # ds) x). i0 < m\) have "f ((x mod d),(x div d)) = x" by (simp add: f_def) show "x \ f ` ({.. {i. low_digits ds i})" by (metis SigmaI \\i0\set (digit_encode ds (x div d)). i0 < m\ \f (x mod d, x div d) = x\ \x div d < prod_list ds\ \x mod d < m\ image_eqI lessThan_iff low_digits_def mem_Collect_eq) qed then have "bij_betw f ({.. {i. low_digits ds i}) {i. low_digits (d # ds) i}" by (simp add: \inj_on f ({.. {i. low_digits ds i})\ bij_betw_def) then show ?thesis by (simp add: bij_betw_same_card) qed then show ?case unfolding `card {i. low_digits ds i} = m ^ (length ds)` card_cartesian_product using low_digits_def by simp qed lemma card_rows_with_1: "card {i\rows_with_1. irows_with_1. i (\i0\set (digit_encode (nths (Tensor.dims Aw) (Collect even)) i). i0 < r)}" (is "?A = ?B") proof (rule subset_antisym; rule subsetI) fix i assume "i \ ?A" then have "i < dim_row Aw'" "\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). 
i0 < last (butlast rs)" using rows_with_1_def by auto then have "i < prod_list (nths (dims Aw) (Collect even))" using dims_Aw' by linarith then have "digit_encode (nths (dims Aw) (Collect even)) i \ nths (dims Aw) (Collect even)" using digit_encode_valid_index by auto have "\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < r" proof fix i0 assume 1:"i0 \ set (digit_encode (nths (dims Aw) (Collect even)) i)" then obtain k where "k < length (digit_encode (nths (dims Aw) (Collect even)) i)" "digit_encode (nths (dims Aw) (Collect even)) i ! k = i0" by (meson in_set_conv_nth) have "i0 < last (butlast rs)" using \\i0\set (digit_encode (nths (dims Aw) (Collect even)) i). i0 < last (butlast rs)\ 1 by blast have "set (nths (dims Aw) (Collect even)) \ {last rs}" unfolding dims_Aw using subset_eq by fastforce then have "nths (dims Aw) (Collect even) ! k = last rs" using \digit_encode (nths (dims Aw) (Collect even)) i \ nths (dims Aw) (Collect even)\ \k < length (digit_encode (nths (dims Aw) (Collect even)) i)\ nth_mem valid_index_length by auto then have "i0 < last rs" using valid_index_lt \digit_encode (nths (dims Aw) (Collect even)) i ! k = i0\ \digit_encode (nths (dims Aw) (Collect even)) i \ nths (dims Aw) (Collect even)\ \k < length (digit_encode (nths (dims Aw) (Collect even)) i)\ valid_index_length by fastforce then show "i0 < r" unfolding r_def by (simp add: \i0 < last (butlast rs)\) qed then show "i \ ?B" using \i < prod_list (nths (dims Aw) (Collect even))\ by blast next fix i assume "i\?B" then show "i\?A" by (simp add: dims_Aw' r_def rows_with_1_def) qed have 2:"\d. d \ set (nths (Tensor.dims Aw) (Collect even)) \ r \ d" proof - fix d assume "d \ set (nths (Tensor.dims Aw) (Collect even))" then have "d \ set (Tensor.dims Aw)" using in_set_nthsD by fast then have "d = last rs" using dims_Aw by simp then show "r \ d" by (simp add: r_def) qed have 3:"0 < r" unfolding r_def by (metis deep diff_diff_cancel diff_zero dual_order.trans in_set_butlastD last_in_set length_butlast list.size(3) min_def nat_le_linear no_zeros not_numeral_le_zero numeral_le_one_iff rel_simps(3)) have 4: "length (nths (Tensor.dims Aw) (Collect even)) = N_half" unfolding length_nths order_Aw using card_even[of N_half] by (metis (mono_tags, lifting) Collect_cong) then show ?thesis using card_low_digits[of "r" "nths (Tensor.dims Aw) (Collect even)"] 1 2 3 4 by metis qed lemma infinite_rows_with_1: "infinite rows_with_1" proof - define listpr where "listpr = prod_list (nths (Tensor.dims Aw) {n. even n})" have "\i. listpr dvd i \ i \ rows_with_1" proof - fix i assume dvd_i: "listpr dvd i" { fix i0::nat assume "i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i)" then have "i0=0" using digit_encode_0 dvd_i listpr_def by auto then have "i0 < last (butlast rs)" using deep no_zeros by (metis Nitpick.size_list_simp(2) One_nat_def Suc_le_lessD in_set_butlastD last_in_set length_butlast length_tl not_numeral_less_zero numeral_2_eq_2 numeral_3_eq_3 numeral_le_one_iff semiring_norm(70)) } then show "i\rows_with_1" by (simp add: rows_with_1_def) qed have 0:"Tensor.dims Aw = replicate (2 ^ (length rs - 2)) (last rs)" unfolding Aw_def using dims_output_witness[OF _ no_zeros y_valid] using deep by linarith then have "listpr > 0" unfolding listpr_def 0 by (metis "0" deep last_in_set length_greater_0_conv less_le_trans no_zeros dims_Aw'_pow(1) dims_Aw'(1) zero_less_numeral zero_less_power) then have "inj (( * ) listpr)" by (metis injI mult_left_cancel neq0_conv) then show ?thesis using `\i. 
listpr dvd i \ i \ rows_with_1` by (meson dvd_triv_left image_subset_iff infinite_iff_countable_subset) qed lemma witness_submatrix: "submatrix Aw' rows_with_1 rows_with_1 = 1\<^sub>m (r^N_half)" proof show "dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1\<^sub>m (r ^ N_half))" unfolding index_one_mat(2) dim_submatrix(1) by (metis (full_types) set_le_in card_rows_with_1) show "dim_col (submatrix Aw' rows_with_1 rows_with_1) = dim_col (1\<^sub>m (r ^ N_half))" by (metis \dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1\<^sub>m (r ^ N_half))\ dim_submatrix(1) dim_submatrix(2) index_one_mat(2) index_one_mat(3) dims_Aw'_pow) show "\i j. i < dim_row (1\<^sub>m (r ^ N_half)) \ j < dim_col (1\<^sub>m (r ^ N_half)) \ submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = 1\<^sub>m (r ^ N_half) $$ (i, j)" proof - fix i j assume "i < dim_row (1\<^sub>m (r ^ N_half))" "j < dim_col (1\<^sub>m (r ^ N_half))" then have "i < r ^ N_half" "j < r ^ N_half" by auto then have "i < card {i \ rows_with_1. i < dim_row Aw'}" "j < card {i \ rows_with_1. i < dim_col Aw'}" using card_rows_with_1 dims_Aw'_pow by auto then have "pick rows_with_1 i < dim_row Aw'" "pick rows_with_1 j < dim_col Aw'" using card_le_pick_inf[OF infinite_rows_with_1, of "dim_row Aw'" i] using card_le_pick_inf[OF infinite_rows_with_1, of "dim_col Aw'" j] by force+ have "\i0\set (digit_encode (nths (dims Aw) (Collect even)) (pick rows_with_1 i)). i0 < last (butlast rs)" using infinite_rows_with_1 pick_in_set_inf rows_with_1_def by auto then have "Aw' $$ (pick rows_with_1 i, pick rows_with_1 j) = (if pick rows_with_1 i = pick rows_with_1 j then 1 else 0)" using witness_matricization[OF `pick rows_with_1 i < dim_row Aw'` `pick rows_with_1 j < dim_col Aw'`] by simp then have "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = (if pick rows_with_1 i = pick rows_with_1 j then 1 else 0)" using submatrix_index by (metis (no_types, lifting) \dim_col (submatrix Aw' rows_with_1 rows_with_1) = dim_col (1\<^sub>m (r ^ N_half))\ \dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1\<^sub>m (r ^ N_half))\ \i < dim_row (1\<^sub>m (r ^ N_half))\ \j < r ^ N_half\ dim_submatrix(1) dim_submatrix(2) index_one_mat(3)) then have "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = (if i = j then 1 else 0)" using pick_eq_iff_inf[OF infinite_rows_with_1] by auto then show "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = 1\<^sub>m (r ^ N_half) $$ (i, j)" by (simp add: \i < r ^ N_half\ \j < r ^ N_half\) qed qed lemma witness_det: "det (submatrix Aw' rows_with_1 rows_with_1) \ 0" unfolding witness_submatrix by simp end (* Examples to show that the locales can be instantiated: *) interpretation example : deep_model_correct_params False "[10,10,10]" unfolding deep_model_correct_params_def by simp interpretation example : deep_model_correct_params_y False "[10,10,10]" 1 unfolding deep_model_correct_params_y_def deep_model_correct_params_y_axioms_def deep_model_correct_params_def by simp end diff --git a/thys/Deep_Learning/DL_Deep_Model_Poly.thy b/thys/Deep_Learning/DL_Deep_Model_Poly.thy --- a/thys/Deep_Learning/DL_Deep_Model_Poly.thy +++ b/thys/Deep_Learning/DL_Deep_Model_Poly.thy @@ -1,363 +1,364 @@ (* Author: Alexander Bentkamp, Universität des Saarlandes *) section \Polynomials representing the Deep Network Model\ theory DL_Deep_Model_Poly imports DL_Deep_Model Polynomials.More_MPoly_Type Jordan_Normal_Form.Determinant begin -definition "polyfun N f = (\p. vars p \ N \ (\x. 
insertion x p = f x))" +definition polyfun :: "nat set \ ((nat \ 'a::comm_semiring_1) \ 'a) \ bool" + where "polyfun N f = (\p. vars p \ N \ (\x. insertion x p = f x))" lemma polyfunI: "(\P. (\p. vars p \ N \ (\x. insertion x p = f x) \ P) \ P) \ polyfun N f" unfolding polyfun_def by metis lemma polyfun_subset: "N\N' \ polyfun N f \ polyfun N' f" unfolding polyfun_def by blast lemma polyfun_const: "polyfun N (\_. c)" proof - have "\x. insertion x (monom 0 c) = c" using insertion_single by (metis insertion_one monom_one mult.commute mult.right_neutral single_zero) then show ?thesis unfolding polyfun_def by (metis (full_types) empty_iff keys_single single_zero subsetI subset_antisym vars_monom_subset) qed lemma polyfun_add: assumes "polyfun N f" "polyfun N g" shows "polyfun N (\x. f x + g x)" proof - obtain p1 p2 where "vars p1 \ N" "\x. insertion x p1 = f x" "vars p2 \ N" "\x. insertion x p2 = g x" using polyfun_def assms by metis then have "vars (p1 + p2) \ N" "\x. insertion x (p1 + p2) = f x + g x" using vars_add using Un_iff subsetCE subsetI apply blast by (simp add: \\x. insertion x p1 = f x\ \\x. insertion x p2 = g x\ insertion_add) then show ?thesis using polyfun_def by blast qed lemma polyfun_mult: assumes "polyfun N f" "polyfun N g" shows "polyfun N (\x. f x * g x)" proof - obtain p1 p2 where "vars p1 \ N" "\x. insertion x p1 = f x" "vars p2 \ N" "\x. insertion x p2 = g x" using polyfun_def assms by metis then have "vars (p1 * p2) \ N" "\x. insertion x (p1 * p2) = f x * g x" using vars_mult using Un_iff subsetCE subsetI apply blast by (simp add: \\x. insertion x p1 = f x\ \\x. insertion x p2 = g x\ insertion_mult) then show ?thesis using polyfun_def by blast qed lemma polyfun_Sum: assumes "finite I" assumes "\i. i\I \ polyfun N (f i)" shows "polyfun N (\x. \i\I. f i x)" using assms apply (induction I rule:finite_induct) apply (simp add: polyfun_const) using comm_monoid_add_class.sum.insert polyfun_add by fastforce lemma polyfun_Prod: assumes "finite I" assumes "\i. i\I \ polyfun N (f i)" shows "polyfun N (\x. \i\I. f i x)" using assms apply (induction I rule:finite_induct) apply (simp add: polyfun_const) using comm_monoid_add_class.sum.insert polyfun_mult by fastforce lemma polyfun_single: assumes "i\N" shows "polyfun N (\x. x i)" proof - have "\f. insertion f (monom (Poly_Mapping.single i 1) 1) = f i" using insertion_single by simp then show ?thesis unfolding polyfun_def using vars_monom_single[of i 1 1] One_nat_def assms singletonD subset_eq by blast qed lemma polyfun_det: assumes "\x. (A x) \ carrier_mat n n" assumes "\x i j. i j polyfun N (\x. (A x) $$ (i,j))" shows "polyfun N (\x. det (A x))" proof - { fix p assume "p\ {p. p permutes {0..x. x < n \ p x < n" using permutes_in_image by auto then have "polyfun N (\x. \i = 0..i x. A x $$ (i, p i)"] assms by simp then have "polyfun N (\x. signof p * (\i = 0..p x. signof p * (\i = 0..f. extract_matrix (\i. f (i + a)) m n $$ (i,j))" unfolding index_extract_matrix[OF assms] apply (rule polyfun_single) using two_digit_le[OF assms] by simp lemma polyfun_mult_mat_vec: assumes "\x. v x \ carrier_vec n" assumes "\j. j polyfun N (\x. v x $ j)" assumes "\x. A x \ carrier_mat m n" assumes "\i j. i j polyfun N (\x. A x $$ (i,j))" assumes "j < m" shows "polyfun N (\x. ((A x) *\<^sub>v (v x)) $ j)" proof - have "\x. j < dim_row (A x)" using `j < m` assms(3) carrier_matD(1) by force have "\x. n = dim_vec (v x)" using assms(1) carrier_vecD by fastforce { fix i assume "i \ {0..i < dim_vec (v x)\) } then have "polyfun N (\x. 
row (A x) j $ i * v x $ i)" using polyfun_mult assms(4)[OF `j < m`] assms(2) by fastforce } then show ?thesis unfolding index_mult_mat_vec[OF `\x. j < dim_row (A x)`] scalar_prod_def using polyfun_Sum[of "{0..i x. row (A x) j $ i * v x $ i"] finite_atLeastLessThan[of 0 n] `\x. n = dim_vec (v x)` by simp qed (* The variable a has been inserted here to make the induction work:*) lemma polyfun_evaluate_net_plus_a: assumes "map dim_vec inputs = input_sizes m" assumes "valid_net m" assumes "j < output_size m" shows "polyfun {..f. evaluate_net (insert_weights s m (\i. f (i + a))) inputs $ j)" using assms proof (induction m arbitrary:inputs j a) case (Input) then show ?case unfolding insert_weights.simps evaluate_net.simps using polyfun_const by metis next case (Conv x m) then obtain x1 x2 where "x=(x1,x2)" by fastforce show ?case unfolding `x=(x1,x2)` insert_weights.simps evaluate_net.simps drop_map unfolding list_of_vec_index proof (rule polyfun_mult_mat_vec) { fix f have 1:"valid_net' (insert_weights s m (\i. f (i + x1 * x2)))" using `valid_net (Conv x m)` valid_net.simps by (metis convnet.distinct(1) convnet.distinct(5) convnet.inject(2) remove_insert_weights) have 2:"map dim_vec inputs = input_sizes (insert_weights s m (\i. f (i + x1 * x2)))" using input_sizes_remove_weights remove_insert_weights by (simp add: Conv.prems(1)) have "dim_vec (evaluate_net (insert_weights s m (\i. f (i + x1 * x2))) inputs) = output_size m" using output_size_correct[OF 1 2] using remove_insert_weights by auto then show "evaluate_net (insert_weights s m (\i. f (i + x1 * x2))) inputs \ carrier_vec (output_size m)" using carrier_vec_def by (metis (full_types) mem_Collect_eq) } have "map dim_vec inputs = input_sizes m" by (simp add: Conv.prems(1)) have "valid_net m" using Conv.prems(2) valid_net.cases by fastforce show "\j. j < output_size m \ polyfun {..f. evaluate_net (insert_weights s m (\i. f (i + x1 * x2 + a))) inputs $ j)" unfolding vec_of_list_index count_weights.simps using Conv(1)[OF `map dim_vec inputs = input_sizes m` `valid_net m`, of _ "x1 * x2 + a"] unfolding semigroup_add_class.add.assoc ab_semigroup_add_class.add.commute[of "x1 * x2" a] by blast have "output_size m = x2" using Conv.prems(2) \x = (x1, x2)\ valid_net.cases by fastforce show "\f. extract_matrix (\i. f (i + a)) x1 x2 \ carrier_mat x1 (output_size m)" unfolding `output_size m = x2` using dim_extract_matrix using carrier_matI by (metis (no_types, lifting)) show "\i j. i < x1 \ j < output_size m \ polyfun {..f. extract_matrix (\i. f (i + a)) x1 x2 $$ (i, j))" unfolding `output_size m = x2` count_weights.simps using polyfun_extract_matrix[of _ x1 _ x2 a "count_weights s m"] by blast show "j < x1" using Conv.prems(3) \x = (x1, x2)\ by auto qed next case (Pool m1 m2 inputs j a) have A2:"\f. map dim_vec (take (length (input_sizes (insert_weights s m1 (\i. f (i + a))))) inputs) = input_sizes m1" by (metis Pool.prems(1) append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights take_map) have B2:"\f. map dim_vec (drop (length (input_sizes (insert_weights s m1 (\i. 
f (i + a))))) inputs) = input_sizes m2" using Pool.prems(1) append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights by (metis drop_map) have A3:"valid_net m1" and B3:"valid_net m2" using `valid_net (Pool m1 m2)` valid_net.simps by blast+ have "output_size (Pool m1 m2) = output_size m2" unfolding output_size.simps using `valid_net (Pool m1 m2)` "valid_net.cases" by fastforce then have A4:"j < output_size m1" and B4:"j < output_size m2" using `j < output_size (Pool m1 m2)` by simp_all let ?net1 = "\f. evaluate_net (insert_weights s m1 (\i. f (i + a))) (take (length (input_sizes (insert_weights s m1 (\i. f (i + a))))) inputs)" let ?net2 = "\f. evaluate_net (insert_weights s m2 (if s then \i. f (i + a) else (\i. f (i + count_weights s m1 + a)))) (drop (length (input_sizes (insert_weights s m1 (\i. f (i + a))))) inputs)" have length1: "\f. output_size m1 = dim_vec (?net1 f)" by (metis A2 A3 input_sizes_remove_weights output_size_correct remove_insert_weights) then have jlength1:"\f. j < dim_vec (?net1 f)" using A4 by metis have length2: "\f. output_size m2 = dim_vec (?net2 f)" by (metis B2 B3 input_sizes_remove_weights output_size_correct remove_insert_weights) then have jlength2:"\f. j < dim_vec (?net2 f)" using B4 by metis have cong1:"\xf. (\f. evaluate_net (insert_weights s m1 (\i. f (i + a))) (take (length (input_sizes (insert_weights s m1 (\i. xf (i + a))))) inputs) $ j) = (\f. ?net1 f $ j)" using input_sizes_remove_weights remove_insert_weights by auto have cong2:"\xf. (\f. evaluate_net (insert_weights s m2 (\i. f (i + (a + (if s then 0 else count_weights s m1))))) (drop (length (input_sizes (insert_weights s m1 (\i. xf (i + a))))) inputs) $ j) = (\f. ?net2 f $ j)" unfolding semigroup_add_class.add.assoc[symmetric] ab_semigroup_add_class.add.commute[of a "if s then 0 else count_weights s m1"] using input_sizes_remove_weights remove_insert_weights by auto show ?case unfolding insert_weights.simps evaluate_net.simps count_weights.simps unfolding index_component_mult[OF jlength1 jlength2] apply (rule polyfun_mult) using Pool.IH(1)[OF A2 A3 A4, of a, unfolded cong1] apply (simp add:polyfun_subset[of "{..f. evaluate_net (insert_weights s m f) inputs $ j)" using polyfun_evaluate_net_plus_a[where a=0, OF assms] by simp lemma polyfun_tensors_from_net: assumes "valid_net m" assumes "is \ input_sizes m" assumes "j < output_size m" shows "polyfun {..f. Tensor.lookup (tensors_from_net (insert_weights s m f) $ j) is)" proof - have 1:"\f. valid_net' (insert_weights s m f)" by (simp add: assms(1) remove_insert_weights) have input_sizes:"\f. input_sizes (insert_weights s m f) = input_sizes m" unfolding input_sizes_remove_weights by (simp add: remove_insert_weights) have 2:"\f. is \ input_sizes (insert_weights s m f)" unfolding input_sizes using assms(2) by blast have 3:"\f. j < output_size' (insert_weights s m f)" by (simp add: assms(3) remove_insert_weights) have "\f1 f2. base_input (insert_weights s m f1) is = base_input (insert_weights s m f2) is" unfolding base_input_def by (simp add: input_sizes) then have "\xf. (\f. evaluate_net (insert_weights s m f) (base_input (insert_weights s m xf) is) $ j) = (\f. evaluate_net (insert_weights s m f) (base_input (insert_weights s m f) is) $ j)" by metis then show ?thesis unfolding lookup_tensors_from_net[OF 1 2 3] using polyfun_evaluate_net[OF base_input_length[OF 2, unfolded input_sizes, symmetric] assms(1) assms(3), of s] by simp qed lemma polyfun_matricize: assumes "\x. dims (T x) = ds" assumes "\is. is \ ds \ polyfun N (\x. 
Tensor.lookup (T x) is)" assumes "\x. dim_row (matricize I (T x)) = nr" assumes "\x. dim_col (matricize I (T x)) = nc" assumes "i < nr" assumes "j < nc" shows "polyfun N (\x. matricize I (T x) $$ (i,j))" proof - let ?weave = "\ x. (weave I (digit_encode (nths ds I ) i) (digit_encode (nths ds (-I )) j))" have 1:"\x. matricize I (T x) $$ (i,j) = Tensor.lookup (T x) (?weave x)" unfolding matricize_def by (metis (no_types, lifting) assms(1) assms(3) assms(4) assms(5) assms(6) case_prod_conv dim_col_mat(1) dim_row_mat(1) index_mat(1) matricize_def) have "\x. ?weave x \ ds" using valid_index_weave(1) assms(2) digit_encode_valid_index dim_row_mat(1) matricize_def using assms digit_encode_valid_index matricize_def by (metis dim_col_mat(1)) then have "polyfun N (\x. Tensor.lookup (T x) (?weave x))" using assms(2) by simp then show ?thesis unfolding 1 using assms(1) by blast qed lemma "(\ (a::nat) < b) = (a \ b)" by (metis not_le) lemma polyfun_submatrix: assumes "\x. (A x) \ carrier_mat m n" assumes "\x i j. i j polyfun N (\x. (A x) $$ (i,j))" assumes "i < card {i. i < m \ i \ I}" assumes "j < card {j. j < n \ j \ J}" assumes "infinite I" "infinite J" shows "polyfun N (\x. (submatrix (A x) I J) $$ (i,j))" proof - have 1:"\x. (submatrix (A x) I J) $$ (i,j) = (A x) $$ (pick I i, pick J j)" using submatrix_index by (metis (no_types, lifting) Collect_cong assms(1) assms(3) assms(4) carrier_matD(1) carrier_matD(2)) have "pick I i < m" "pick J j < n" using card_le_pick_inf[OF `infinite I`] card_le_pick_inf[OF `infinite J`] `i < card {i. i < m \ i \ I}`[unfolded set_le_in] `j < card {j. j < n \ j \ J}`[unfolded set_le_in] not_less by metis+ then show ?thesis unfolding 1 by (simp add: assms(2)) qed context deep_model_correct_params_y begin definition witness_submatrix where "witness_submatrix f = submatrix (A' f) rows_with_1 rows_with_1" lemma polyfun_tensor_deep_model: assumes "is \ input_sizes (deep_model_l rs)" shows "polyfun {..f. Tensor.lookup (tensors_from_net (insert_weights shared_weights (deep_model_l rs) f) $ y) is)" proof - have 1:"\f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis then have "y < output_size ( deep_model_l rs)" using valid_deep_model y_valid length_output_deep_model by force have 0:"{..f. A' f $$ (i,j))" proof - have 0:"y < output_size ( deep_model_l rs )" using valid_deep_model y_valid length_output_deep_model by force have 1:"\f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis have 2:"(\f is. is \ replicate (2 * N_half) (last rs) \ polyfun {..x. Tensor.lookup (A x) is))" unfolding A_def using polyfun_tensor_deep_model[unfolded input_sizes_deep_model] 0 by blast show ?thesis unfolding A'_def A_def apply (rule polyfun_matricize) using dims_tensor_deep_model[OF 1] 2[unfolded A_def] using dims_A'_pow[unfolded A'_def A_def] `i<(last rs) ^ N_half` `j<(last rs) ^ N_half` by auto qed lemma polyfun_submatrix_deep_model: assumes "i < r ^ N_half" assumes "j < r ^ N_half" shows "polyfun {..f. witness_submatrix f $$ (i,j))" unfolding witness_submatrix_def proof (rule polyfun_submatrix) have 1:"\f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis show "\f. A' f \ carrier_mat ((last rs) ^ N_half) ((last rs) ^ N_half)" using "1" dims_A'_pow using weight_space_dim_def by auto show "\f i j. i < last rs ^ N_half \ j < last rs ^ N_half \ polyfun {..f. 
A' f $$ (i, j))" using polyfun_matrix_deep_model weight_space_dim_def by force show "i < card {i. i < last rs ^ N_half \ i \ rows_with_1}" using assms(1) card_rows_with_1 dims_Aw'_pow set_le_in by metis show "j < card {i. i < last rs ^ N_half \ i \ rows_with_1}" using assms(2) card_rows_with_1 dims_Aw'_pow set_le_in by metis show "infinite rows_with_1" "infinite rows_with_1" by (simp_all add: infinite_rows_with_1) qed lemma polyfun_det_deep_model: shows "polyfun {..f. det (witness_submatrix f))" proof (rule polyfun_det) fix f have "remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs" using remove_insert_weights by metis show "witness_submatrix f \ carrier_mat (r ^ N_half) (r ^ N_half)" unfolding witness_submatrix_def apply (rule carrier_matI) unfolding dim_submatrix[unfolded set_le_in] unfolding dims_A'_pow[unfolded weight_space_dim_def] using card_rows_with_1 dims_Aw'_pow by simp_all show "\i j. i < r ^ N_half \ j < r ^ N_half \ polyfun {..f. witness_submatrix f $$ (i, j))" using polyfun_submatrix_deep_model by blast qed end end