diff --git a/thys/Deep_Learning/DL_Deep_Model.thy b/thys/Deep_Learning/DL_Deep_Model.thy
--- a/thys/Deep_Learning/DL_Deep_Model.thy
+++ b/thys/Deep_Learning/DL_Deep_Model.thy
@@ -1,957 +1,951 @@
(* Author: Alexander Bentkamp, Universität des Saarlandes *)

section \<open>Deep Network Model\<close>

theory DL_Deep_Model
imports DL_Network Tensor_Matricization Jordan_Normal_Form.DL_Submatrix DL_Concrete_Matrices
  DL_Missing_Finite_Set Jordan_Normal_Form.DL_Missing_Sublist Jordan_Normal_Form.Determinant
begin

hide_const(open) Polynomial.order

fun deep_model and deep_model' where
"deep_model' Y [] = Input Y" |
"deep_model' Y (r # rs) = Pool (deep_model Y r rs) (deep_model Y r rs)" |
"deep_model Y r rs = Conv (Y,r) (deep_model' r rs)"

abbreviation "deep_model'_l rs == deep_model' (rs!0) (tl rs)"
abbreviation "deep_model_l rs == deep_model (rs!0) (rs!1) (tl (tl rs))"

lemma valid_deep_model: "valid_net (deep_model Y r rs)"
  apply (induction rs arbitrary: Y r)
  apply (simp add: valid_net.intros(1) valid_net.intros(2))
  using valid_net.intros(2) valid_net.intros(3) by auto

lemma valid_deep_model': "valid_net (deep_model' r rs)"
  apply (induction rs arbitrary: r)
  apply (simp add: valid_net.intros(1))
  by (metis deep_model'.elims deep_model'.simps(2) deep_model.elims output_size.simps valid_net.simps)

lemma input_sizes_deep_model':
assumes "length rs \<ge> 1"
shows "input_sizes (deep_model'_l rs) = replicate (2^(length rs - 1)) (last rs)"
using assms proof (induction "butlast rs" arbitrary:rs)
  case Nil
  then have "rs = [rs!0]"
    by (metis One_nat_def diff_diff_cancel diff_zero length_0_conv length_Suc_conv length_butlast nth_Cons_0)
  then have "input_sizes (deep_model'_l rs) = [last rs]"
    by (metis deep_model'.simps(1) input_sizes.simps(1) last.simps list.sel(3))
  then show "input_sizes (deep_model'_l rs) = replicate (2 ^ (length rs - 1)) (last rs)"
    by (metis One_nat_def \<open>[] = butlast rs\<close> empty_replicate length_butlast list.size(3) power_0 replicate.simps(2))
next
  case (Cons r rs' rs)
  then have IH: "input_sizes (deep_model'_l (tl rs)) = replicate (2 ^ (length (tl rs) - 1)) (last rs)"
    by (metis (no_types, lifting) One_nat_def butlast_tl diff_is_0_eq' last_tl length_Cons length_butlast length_tl list.sel(3) list.size(3) nat_le_linear not_one_le_zero)
  have "rs = r # (tl rs)"
    by (metis Cons.hyps(2) Cons.prems One_nat_def append_Cons append_butlast_last_id length_greater_0_conv less_le_trans list.sel(3) zero_less_Suc)
  then have "deep_model'_l rs = Pool (deep_model_l rs) (deep_model_l rs)"
    by (metis Cons.hyps(2) One_nat_def butlast.simps(2) deep_model'.elims list.sel(3) list.simps(3) nth_Cons_0 nth_Cons_Suc)
  then have "input_sizes (deep_model'_l rs) = input_sizes (deep_model_l rs) @ input_sizes (deep_model_l rs)"
    using input_sizes.simps(3) by metis
  also have "... = input_sizes (deep_model'_l (tl rs)) @ input_sizes (deep_model'_l (tl rs))"
    by (metis (no_types, lifting) Cons.hyps(2) One_nat_def deep_model.elims input_sizes.simps(2) length_Cons length_butlast length_greater_0_conv length_tl list.sel(2) list.sel(3) list.size(3) nth_tl one_neq_zero)
  also have "... = replicate (2 ^ (length (tl rs) - 1)) (last rs) @ replicate (2 ^ (length (tl rs) - 1)) (last rs)"
    using IH by auto
  also have "...
= replicate (2 ^ (length rs - 1)) (last rs)" using replicate_add[of "2 ^ (length (tl rs) - 1)" "2 ^ (length (tl rs) - 1)" "last rs"] by (metis Cons.hyps(2) One_nat_def butlast_tl length_butlast list.sel(3) list.size(4) mult_2_right power_add power_one_right) finally show ?case by auto qed lemma input_sizes_deep_model: assumes "length rs \ 2" shows "input_sizes (deep_model_l rs) = replicate (2^(length rs - 2)) (last rs)" proof - have "input_sizes (deep_model_l rs) = input_sizes (deep_model'_l (tl rs))" by (metis One_nat_def Suc_1 assms hd_Cons_tl deep_model.elims input_sizes.simps(2) length_Cons length_greater_0_conv lessI linorder_not_le list.size(3) not_numeral_le_zero nth_tl) also have "... = replicate (2^(length rs - 2)) (last rs)" using input_sizes_deep_model' by (metis (no_types, lifting) One_nat_def Suc_1 Suc_eq_plus1 assms diff_diff_left hd_Cons_tl last_tl length_Cons length_tl linorder_not_le list.size(3) not_less_eq not_numeral_le_zero numeral_le_one_iff semiring_norm(69)) finally show ?thesis by auto qed lemma evaluate_net_Conv_id: assumes "valid_net' m" and "input_sizes m = map dim_vec input" and "j Tensor.dims (?a$i)" then have "is \ input_sizes m" using `Tensor.dims (?a$i) = input_sizes m` by auto have "valid_net' Convm" by (simp add: assms eye_matrix_dim valid_net.intros(2) Convm_def) have "base_input m is = base_input Convm is" by (simp add: Convm_def base_input_def) have "i < output_size' Convm" unfolding Convm_def remove_weights.simps output_size.simps eye_matrix_dim using assms by metis have "is \ input_sizes (Conv (eye_matrix nr (output_size' m)) m)" by (metis \is \ input_sizes m\ input_sizes.simps(2)) then have f1: "lookup (tensors_from_net (Conv (eye_matrix nr (output_size' m)) m) $ i) is = evaluate_net (Conv (eye_matrix nr (output_size' m)) m) (base_input (Conv (eye_matrix nr (output_size' m)) m) is) $ i" using Convm_def \i < output_size' Convm\ \valid_net' Convm\ lookup_tensors_from_net by blast have "lookup (tensor0 (input_sizes m)) is = (0::real)" by (meson \is \ input_sizes m\ lookup_tensor0) then show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is" using Convm_def \base_input m is = base_input Convm is\ \is \ input_sizes m\ assms(1) assms(2) base_input_length evaluate_net_Conv_id f1 lookup_tensors_from_net by auto qed lemma evaluate_net_Conv_copy_first: assumes "valid_net' m" and "input_sizes m = map dim_vec input" and "j 0" shows "evaluate_net (Conv (copy_first_matrix nr (output_size' m)) m) input $ j = evaluate_net m input $ 0" unfolding evaluate_net.simps output_size_correct[OF assms(1) assms(2)[symmetric]] using mult_copy_first_matrix[OF `joutput_size' m = dim_vec (evaluate_net m input)\ assms(4)) lemma tensors_from_net_Conv_copy_first: assumes "valid_net' m" and "i 0" shows "tensors_from_net (Conv (copy_first_matrix nr (output_size' m)) m) $ i = tensors_from_net m $ 0" (is "?a $ i = ?b") proof (rule tensor_lookup_eqI) have "Tensor.dims (?a$i) = input_sizes m" by (metis assms(1) assms(2) copy_first_matrix_dim(1) copy_first_matrix_dim(2) dims_tensors_from_net input_sizes.simps(2) output_size.simps(2) output_size_correct_tensors remove_weights.simps(2) valid_net.intros(2) vec_setI) moreover have "Tensor.dims (?b) = input_sizes m" using dims_tensors_from_net output_size_correct_tensors[OF assms(1)] using assms(3) by (simp add: vec_setI) ultimately show "Tensor.dims (?a$i) = Tensor.dims (?b)" by auto define Convm where "Convm = Conv (copy_first_matrix nr (output_size' m)) m" fix "is" assume "is \ Tensor.dims (?a$i)" then have "is \ input_sizes m" using 
`Tensor.dims (?a$i) = input_sizes m` by auto have "valid_net' Convm" by (simp add: assms copy_first_matrix_dim valid_net.intros(2) Convm_def) have "base_input m is = base_input Convm is" by (simp add: Convm_def base_input_def) have "i < output_size' Convm" unfolding Convm_def remove_weights.simps output_size.simps copy_first_matrix_dim using assms by metis show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is" by (metis Convm_def \base_input m is = base_input Convm is\ \i < output_size' Convm\ \is \ input_sizes m\ \valid_net' Convm\ assms(1) assms(2) assms(3) base_input_length evaluate_net_Conv_copy_first input_sizes.simps(2) lookup_tensors_from_net) qed lemma evaluate_net_Conv_all1: assumes "valid_net' m" and "input_sizes m = map dim_vec input" and "i Tensor.dims (?a $ i)" then have "is \ input_sizes m" using \i < dim_vec ?a\ dims_tensors_from_net input_sizes.simps(2) by (metis vec_setI) then have "is \ input_sizes Convm" by (simp add: Convm_def) have "valid_net' Convm" by (simp add: Convm_def assms all1_matrix_dim valid_net.intros(2)) have "i< output_size' Convm" using Convm_def \i < dim_vec ?a\ \valid_net' Convm\ output_size_correct_tensors by presburger have "base_input Convm is = base_input m is" unfolding base_input_def Convm_def input_sizes.simps by metis have "Tensor.lookup (?a $ i) is = evaluate_net Convm (base_input Convm is) $ i" using lookup_tensors_from_net[OF `valid_net' Convm` `is \ input_sizes Convm` `i< output_size' Convm`] by (metis Convm_def ) also have "... = monoid_add_class.sum_list (list_of_vec (evaluate_net m (base_input Convm is)))" using evaluate_net_Conv_all1 Convm_def \is \ input_sizes Convm\ assms base_input_length \i < nr\ by simp also have "... = monoid_add_class.sum_list (list_of_vec (map_vec (\A. lookup A is)(tensors_from_net m)))" unfolding `base_input Convm is = base_input m is` using lookup_tensors_from_net[OF `valid_net' m` `is \ input_sizes m`] base_input_length[OF \is \ input_sizes m\] output_size_correct[OF assms(1)] output_size_correct_tensors[OF assms(1)] eq_vecI[of "evaluate_net m (base_input m is)" "map_vec (\A. lookup A is) (tensors_from_net m)"] index_map_vec(1) index_map_vec(2) by force also have "... = monoid_add_class.sum_list (map (\A. lookup A is) (list_of_vec (tensors_from_net m)))" using eq_vecI[of "vec_of_list (list_of_vec (map_vec (\A. lookup A is)(tensors_from_net m)))" "vec_of_list (map (\A. lookup A is) (list_of_vec (tensors_from_net m)))"] dim_vec_of_list nth_list_of_vec length_map list_vec nth_map index_map_vec(1) index_map_vec(2) vec_list by (metis (no_types, lifting)) also have "... 
= Tensor.lookup ?b is" using dims_tensors_from_net set_list_of_vec using lookup_listsum[OF `is \ input_sizes m`, of "list_of_vec (tensors_from_net m)"] by metis finally show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is" by blast qed fun witness and witness' where "witness' Y [] = Input Y" | "witness' Y (r # rs) = Pool (witness Y r rs) (witness Y r rs)" | "witness Y r rs = Conv ((if length rs = 0 then eye_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) (witness' r rs)" abbreviation "witness_l rs == witness (rs!0) (rs!1) (tl (tl rs))" abbreviation "witness'_l rs == witness' (rs!0) (tl rs)" lemma witness_is_deep_model: "remove_weights (witness Y r rs) = deep_model Y r rs" proof (induction rs arbitrary: Y r) case Nil then show ?case unfolding witness.simps witness'.simps deep_model.simps deep_model'.simps by (simp add: eye_matrix_dim) next case (Cons r' rs Y r) have "dim_row ((if length (r' # rs) = 0 then eye_matrix else (if length (r' # rs) = 1 then all1_matrix else copy_first_matrix)) Y r) = Y" "dim_col ((if length (r' # rs) = 0 then eye_matrix else (if length (r' # rs) = 1 then all1_matrix else copy_first_matrix)) Y r) = r" by (simp_all add: all1_matrix_dim copy_first_matrix_dim) then show ?case unfolding witness.simps unfolding witness'.simps unfolding remove_weights.simps using Cons by simp qed lemma witness'_is_deep_model: "remove_weights (witness' Y rs) = deep_model' Y rs" proof (induction rs arbitrary: Y) case Nil then show ?case unfolding witness.simps witness'.simps deep_model.simps deep_model'.simps by (simp add: eye_matrix_dim) next case (Cons r rs Y) have "dim_row ((if length rs = 0 then eye_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) = Y" "dim_col ((if length rs = 0 then eye_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) = r" by (simp_all add: all1_matrix_dim copy_first_matrix_dim eye_matrix_dim) then show ?case unfolding witness'.simps unfolding witness.simps unfolding remove_weights.simps using Cons by simp qed lemma witness_valid: "valid_net' (witness Y r rs)" using valid_deep_model witness_is_deep_model by auto lemma witness'_valid: "valid_net' (witness' Y rs)" using valid_deep_model' witness'_is_deep_model by auto lemma shared_weight_net_witness: "shared_weight_net (witness Y r rs)" proof (induction rs arbitrary:Y r) case Nil then show ?case unfolding witness.simps witness'.simps by (simp add: shared_weight_net_Conv shared_weight_net_Input) next case (Cons a rs) then show ?case unfolding witness.simps witness'.simps by (simp add: shared_weight_net_Conv shared_weight_net_Input shared_weight_net_Pool) qed lemma witness_l0': "witness' Y [M] = (Pool (Conv (eye_matrix Y M) (Input M)) (Conv (eye_matrix Y M) (Input M)) )" unfolding witness'.simps witness.simps by simp lemma witness_l1: "witness Y r0 [M] = Conv (all1_matrix Y r0) (witness' r0 [M])" unfolding witness'.simps by simp lemma tensors_ht_l0: assumes "j unit_vec M j = tensor_from_lookup [M,M] (\is. 
if is=[j,j] then 1 else 0)" (is "?A=?B") proof (rule tensor_lookup_eqI) show "Tensor.dims ?A = Tensor.dims ?B" by (metis append_Cons self_append_conv2 dims_unit_vec dims_tensor_prod dims_tensor_from_lookup) fix "is" assume is_valid:"is \ Tensor.dims (unit_vec M j \ unit_vec M j)" then have "is \ [M,M]" by (metis append_Cons self_append_conv2 dims_unit_vec dims_tensor_prod) then obtain i1 i2 where is_split: "is = [i1, i2]" "i1 < M" "i2 < M" using list.distinct(1) by blast then have "[i1] \ Tensor.dims (unit_vec M j)" "[i2] \ Tensor.dims (unit_vec M j)" by (simp_all add: valid_index.Cons valid_index.Nil dims_unit_vec) have "is = [i1] @ [i2]" by (simp add: is_split(1)) show "Tensor.lookup ?A is = Tensor.lookup ?B is" unfolding `is = [i1] @ [i2]` lookup_tensor_prod[OF `[i1] \ Tensor.dims (unit_vec M j)` `[i2] \ Tensor.dims (unit_vec M j)`] lookup_tensor_from_lookup[OF \is \ [M, M]\, unfolded `is = [i1] @ [i2]`] lookup_unit_vec[OF \i1 < M\] lookup_unit_vec[OF \i2 < M\] by fastforce qed lemma tensors_ht_l0': assumes "j unit_vec M j else tensor0 [M,M])" (is "_ = ?b") proof - have "valid_net' (Conv (eye_matrix r0 M) (Input M))" by (metis convnet.inject(3) list.discI witness'.elims witness_l0' witness_valid) have j_le:"j < dim_vec (tensors_from_net (Conv (eye_matrix r0 M) (Input M)))" using output_size_correct_tensors[OF `valid_net' (Conv (eye_matrix r0 M) (Input M))`, unfolded remove_weights.simps output_size.simps eye_matrix_dim] assms by simp show ?thesis unfolding tensors_from_net.simps(3) witness_l0' index_component_mult[OF j_le j_le] tensors_ht_l0[OF assms] by auto qed lemma lookup_tensors_ht_l0': assumes "j [M,M]" shows "(Tensor.lookup (tensors_from_net (witness' r0 [M]) $ j)) is = (if is=[j,j] then 1 else 0)" proof (cases "jj [j, j]" using assms(2) using list.distinct(1) nth_Cons_0 valid_index.simps by blast show ?thesis unfolding tensors_ht_l0'[OF assms(1)] tensor_prod_unit_vec using `\j [j, j]`) qed lemma lookup_tensors_ht_l1: assumes "j < r1" and "is \ [M,M]" shows "Tensor.lookup (tensors_from_net (witness r1 r0 [M]) $ j) is = (if is!0 = is!1 \ is!0output_size' (witness' r0 [M]) = r0\ witness_l0'_valid output_size_correct_tensors by fastforce have all0_but1:"\i. i\is!0 \ i Tensor.lookup (tensors_from_net (witness' r0 [M]) $ i) is = 0" using lookup_tensors_ht_l0' \is \ [M, M]\ by auto have "tensors_from_net (witness r1 r0 [M]) $ j = Tensor_Plus.listsum [M,M] (list_of_vec (tensors_from_net (witness' r0 [M])))" unfolding witness_l1 using tensors_from_net_Conv_all1[OF witness_l0'_valid assms(1)] witness_l0' `output_size' (witness' r0 [M]) = r0` by simp then have "Tensor.lookup (tensors_from_net (witness r1 r0 [M]) $ j) is = monoid_add_class.sum_list (map (\A. Tensor.lookup A is) (list_of_vec (tensors_from_net (witness' r0 [M]))))" using lookup_listsum[OF `is \ [M, M]`] \input_sizes (witness' r0 [M]) = [M, M]\ dims_tensors_from_net by (metis set_list_of_vec) also have "... = monoid_add_class.sum_list (map (\i. lookup (tensors_from_net (witness' r0 [M]) $ i) is) [0..A. Tensor.lookup A is)" "\i. (tensors_from_net (witness' r0 [M]) $ i)" "[0..i is!0 {0..i {0..i. (Tensor.lookup (tensors_from_net (witness' r0 [M])$i) is)"] using all0_but1 atLeast0LessThan by force then show ?thesis using lookup_tensors_ht_l0' \is ! 0 < r0\ \is \ [M, M]\ by fastforce next case False then show ?thesis using all0_but1 atLeast0LessThan sum.neutral by force qed finally show ?thesis by auto qed lemma length_output_deep_model: assumes "remove_weights m = deep_model_l rs" shows "dim_vec (tensors_from_net m) = rs ! 
0" using output_size_correct_tensors valid_deep_model deep_model.elims output_size.simps(2) by (metis assms) lemma length_output_deep_model': assumes "remove_weights m = deep_model'_l rs" shows "dim_vec (tensors_from_net m) = rs ! 0" using output_size_correct_tensors valid_deep_model' deep_model'.elims output_size.simps by (metis assms deep_model.elims) lemma length_output_witness: "dim_vec (tensors_from_net (witness_l rs)) = rs ! 0" using length_output_deep_model witness_is_deep_model by blast lemma length_output_witness': "dim_vec (tensors_from_net (witness'_l rs)) = rs ! 0" using length_output_deep_model' witness'_is_deep_model by blast lemma dims_output_deep_model: assumes "length rs \ 2" and "\r. r\set rs \ r > 0" and "j < rs!0" and "remove_weights m = deep_model_l rs" shows "Tensor.dims (tensors_from_net m $ j) = replicate (2^(length rs - 2)) (last rs)" using dims_tensors_from_net input_sizes_deep_model[OF assms(1)] output_size_correct_tensors valid_deep_model assms(3) assms(4) input_sizes_remove_weights length_output_witness witness_is_deep_model by (metis vec_setI) lemma dims_output_witness: assumes "length rs \ 2" and "\r. r\set rs \ r > 0" and "j < rs!0" shows "Tensor.dims (tensors_from_net (witness_l rs) $ j) = replicate (2^(length rs - 2)) (last rs)" using dims_output_deep_model witness_is_deep_model assms by blast lemma dims_output_deep_model': assumes "length rs \ 1" and "\r. r\set rs \ r > 0" and "j < rs!0" and "remove_weights m = deep_model'_l rs" shows "Tensor.dims (tensors_from_net m $ j) = replicate (2^(length rs - 1)) (last rs)" proof - have "dim_vec (tensors_from_net m) > j" using length_output_deep_model' `remove_weights m = deep_model'_l rs` `j < rs!0` by auto then have "Tensor.dims (tensors_from_net m $ j) = input_sizes m" using dims_tensors_from_net[of _ m] output_size_correct_tensors vec_setI by metis then show ?thesis using assms(1) input_sizes_deep_model' input_sizes_remove_weights[of m, unfolded `remove_weights m = deep_model'_l rs`] by auto qed lemma dims_output_witness': assumes "length rs \ 1" and "\r. r\set rs \ r > 0" and "j < rs!0" shows "Tensor.dims (tensors_from_net (witness'_l rs) $ j) = replicate (2^(length rs - 1)) (last rs)" using dims_output_deep_model' assms witness'_is_deep_model by blast abbreviation "ten2mat == matricize {n. even n}" abbreviation "mat2ten == dematricize {n. even n}" locale deep_model_correct_params = fixes shared_weights::bool fixes rs::"nat list" assumes deep:"length rs \ 3" and no_zeros:"\r. r\set rs \ 0 < r" begin definition "r = min (last rs) (last (butlast rs))" definition "N_half = 2^(length rs - 3)" definition "weight_space_dim = count_weights shared_weights (deep_model_l rs)" end locale deep_model_correct_params_y = deep_model_correct_params + fixes y::nat assumes y_valid:"y < rs ! 
0" begin definition "A ws = tensors_from_net (insert_weights shared_weights (deep_model_l rs) ws) $ y" definition "A' ws = ten2mat (A ws)" lemma dims_tensor_deep_model: assumes "remove_weights m = deep_model_l rs" shows "dims (tensors_from_net m $ y) = replicate (2 * N_half) (last rs)" proof - have "dims (tensors_from_net m $ y) = replicate (2 ^ (length rs - 2)) (last rs)" using dims_output_deep_model[OF _ no_zeros y_valid assms] using less_imp_le_nat Suc_le_lessD deep numeral_3_eq_3 by auto then show ?thesis using N_half_def by (metis One_nat_def Suc_1 Suc_eq_plus1 Suc_le_lessD deep diff_diff_left less_numeral_extra(3) numeral_3_eq_3 realpow_num_eq_if zero_less_diff) qed lemma order_tensor_deep_model: assumes "remove_weights m = deep_model_l rs" shows "order (tensors_from_net m $ y) = 2 * N_half" using dims_tensor_deep_model by (simp add: assms) lemma dims_A: shows "Tensor.dims (A ws) = replicate (2 * N_half) (last rs)" unfolding A_def using dims_tensor_deep_model remove_insert_weights by blast lemma order_A: shows "order (A ws) = 2 * N_half" using dims_A length_replicate by auto lemma dims_A': shows "dim_row (A' ws) = prod_list (nths (Tensor.dims (A ws)) {n. even n})" and "dim_col (A' ws) = prod_list (nths (Tensor.dims (A ws)) {n. odd n})" unfolding A'_def matricize_def by (simp_all add: A_def Collect_neg_eq) lemma dims_A'_pow: shows "dim_row (A' ws) = (last rs) ^ N_half" "dim_col (A' ws) = (last rs) ^ N_half" unfolding dims_A' dims_A nths_replicate set_le_in card_even card_odd prod_list_replicate by simp_all definition "Aw = tensors_from_net (witness_l rs) $ y" definition "Aw' = ten2mat Aw" -definition "witness_weights = (SOME ws. witness_l rs = insert_weights shared_weights (deep_model_l rs) ws)" +definition "witness_weights = extract_weights shared_weights (witness_l rs)" lemma witness_weights:"witness_l rs = insert_weights shared_weights (deep_model_l rs) witness_weights" -proof - - have 0:"\x. witness_l rs = insert_weights shared_weights (deep_model_l rs) x" - unfolding weight_space_dim_def using shared_weight_net_witness insert_extract_weights_cong_shared - insert_extract_weights_cong_unshared witness_is_deep_model by (metis (full_types)) - show "witness_l rs = insert_weights shared_weights (deep_model_l rs) witness_weights" - unfolding witness_weights_def using someI_ex[OF 0] by blast -qed + by (metis (full_types) insert_extract_weights_cong_shared insert_extract_weights_cong_unshared shared_weight_net_witness witness_is_deep_model witness_weights_def) lemma Aw_def': "Aw = A witness_weights" unfolding Aw_def A_def using witness_weights by auto lemma Aw'_def': "Aw' = A' witness_weights" unfolding Aw'_def A'_def Aw_def' by auto lemma dims_Aw: "Tensor.dims Aw = replicate (2 * N_half) (last rs)" unfolding Aw_def' using dims_A by auto lemma order_Aw: "order Aw = 2 * N_half" unfolding Aw_def' using order_A by auto lemma dims_Aw': "dim_row Aw' = prod_list (nths (Tensor.dims Aw) {n. even n})" "dim_col Aw' = prod_list (nths (Tensor.dims Aw) {n. odd n})" unfolding Aw'_def' Aw_def' using dims_A' by auto lemma dims_Aw'_pow: "dim_row Aw' = (last rs) ^ N_half" "dim_col Aw' = (last rs) ^ N_half" unfolding Aw'_def' Aw_def' using dims_A'_pow by auto lemma witness_tensor: assumes "is \ Tensor.dims Aw" shows "Tensor.lookup Aw is = (if nths is {n. even n} = nths is {n. odd n} \ (\i\set is. 
lemma witness_tensor:
assumes "is \<lhd> Tensor.dims Aw"
shows "Tensor.lookup Aw is = (if nths is {n. even n} = nths is {n. odd n}
  \<and> (\<forall>i\<in>set is. i < last (butlast rs)) then 1 else 0)"
using assms deep no_zeros y_valid unfolding Aw_def
proof (induction "butlast (butlast (butlast rs))" arbitrary:rs "is" y)
  case Nil
  have "length rs = 3"
    by (rule antisym,
      metis Nil.hyps One_nat_def Suc_1 Suc_eq_plus1 add_2_eq_Suc' diff_diff_left length_butlast less_numeral_extra(3) list.size(3) not_le numeral_3_eq_3 zero_less_diff,
      metis `3 \<le> length rs`)
  then have "rs = [rs!0, rs!1, rs!2]"
    by (metis (no_types, lifting) Cons_nth_drop_Suc One_nat_def Suc_eq_plus1 append_Nil id_take_nth_drop length_0_conv length_tl lessI list.sel(3) list.size(4) not_le numeral_3_eq_3 numeral_le_one_iff one_add_one semiring_norm(70) take_0 zero_less_Suc)
  have "input_sizes (witness_l [rs ! 0, rs ! 1, rs ! 2]) = [rs!2, rs!2]"
    using witness.simps witness'.simps input_sizes.simps by auto
  then have "Tensor.dims (tensors_from_net (witness_l rs) $ y) = [rs!2, rs!2]"
    using dims_tensors_from_net[of "tensors_from_net (witness_l rs) $ y" "witness_l rs"]
      Nil.prems(4) length_output_witness \<open>rs = [rs ! 0, rs ! 1, rs ! 2]\<close> vec_setI by metis
  then have "is \<lhd> [rs!2, rs!2]" using Nil.prems by metis
  then have "Tensor.lookup ((tensors_from_net (witness_l rs))$y) is = (if is ! 0 = is ! 1 \<and> is ! 0 < rs ! 1 then 1 else 0)"
    using Nil.prems(4) \<open>rs = [rs ! 0, rs ! 1, rs ! 2]\<close> by (metis list.sel(3) lookup_tensors_ht_l1)
  have "is ! 0 = is ! 1 \<and> is ! 0 < rs ! 1
    \<longleftrightarrow> nths is {n. even n} = nths is {n. odd n} \<and> (\<forall>i\<in>set is. i < last (butlast rs))"
  proof -
    have "length is = 2"
      by (metis One_nat_def Suc_eq_plus1 \<open>is \<lhd> [rs ! 2, rs ! 2]\<close> list.size(3) list.size(4) numeral_2_eq_2 valid_index_length)
    have "nths is {n. even n} = [is!0]"
      apply (rule nths_only_one)
      using subset_antisym less_2_cases `length is = 2` by fastforce
    have "nths is {n. odd n} = [is!1]"
      apply (rule nths_only_one)
      using subset_antisym less_2_cases `length is = 2` by fastforce
    have "last (butlast rs) = rs!1"
      by (metis One_nat_def Suc_eq_plus1 \<open>rs = [rs ! 0, rs ! 1, rs ! 2]\<close> append_butlast_last_id last_conv_nth length_butlast length_tl lessI list.sel(3) list.simps(3) list.size(3) list.size(4) nat.simps(3) nth_append)
    show ?thesis unfolding `last (butlast rs) = rs!1`
      apply (rule iffI; rule conjI)
      apply (simp add: \<open>nths is (Collect even) = [is ! 0]\<close> \<open>nths is {n. odd n} = [is ! 1]\<close>)
      apply (metis `length is = 2` One_nat_def in_set_conv_nth less_2_cases)
      apply (simp add: \<open>nths is (Collect even) = [is ! 0]\<close> \<open>nths is {n. odd n} = [is ! 1]\<close>)
      apply (simp add: \<open>length is = 2\<close>)
      done
  qed
  then show ?case
    unfolding \<open>Tensor.lookup (tensors_from_net (witness_l rs) $ y) is = (if is ! 0 = is ! 1 \<and> is ! 0 < rs ! 1 then 1 else 0)\<close>
    using witness_is_deep_model witness_valid \<open>rs = [rs ! 0, rs ! 1, rs ! 2]\<close> by auto
next
  case (Cons r rs' rs "is" j)
  text \<open>We prove the Induction Hypothesis for "tl rs" and j=0:\<close>
  have "rs = r # tl rs"
    by (metis Cons.hyps(2) append_butlast_last_id butlast.simps(1) hd_append2 list.collapse list.discI list.sel(1))
  have 1:"rs' = butlast (butlast (butlast (tl rs)))" by (metis Cons.hyps(2) butlast_tl list.sel(3))
  have 2:"3 \<le> length (tl rs)"
    by (metis (no_types, lifting) Cons.hyps(2) Cons.prems(2) Nitpick.size_list_simp(2) One_nat_def Suc_eq_plus1 \<open>rs = r # tl rs\<close> \<open>rs' = butlast (butlast (butlast (tl rs)))\<close> diff_diff_left diff_self_eq_0 gr0_conv_Suc le_Suc_eq length_butlast length_tl less_numeral_extra(3) list.simps(3) numeral_3_eq_3)
  have 3:"\<And>r. r \<in> set (tl rs) \<Longrightarrow> 0 < r" by (metis Cons.prems(3) list.sel(2) list.set_sel(2))
  have 4:"0 < (tl rs) ! 0" using "2" "3" by auto
  have IH: "\<And>is'.
is' \ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0) \ Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is' = (if nths is' (Collect even) = nths is' {n. odd n} \ (\i\set is'. i < last (butlast (tl rs))) then 1 else 0)" using "1" "2" "3" 4 Cons.hyps(1) by blast text \The list "is" can be split in two parts:\ have "is \ replicate (2^(length rs - 2)) (last rs)" using Cons.prems(3) dims_output_witness 2 by (metis (no_types, lifting) Cons.prems(1) Cons.prems(3) Cons.prems(4) Nitpick.size_list_simp(2) One_nat_def diff_diff_left diff_is_0_eq length_tl nat_le_linear not_numeral_le_zero numeral_le_one_iff one_add_one semiring_norm(70)) then have "is \ replicate (2^(length (tl rs) - 2)) (last rs) @ replicate (2^(length (tl rs) - 2)) (last rs)" using Cons.prems dims_output_witness by (metis "2" Nitpick.size_list_simp(2) One_nat_def diff_diff_left length_tl mult_2 not_numeral_le_zero numeral_le_one_iff one_add_one power.simps(2) replicate_add semiring_norm(70)) then obtain is1 is2 where "is = is1 @ is2" and is1_replicate: "is1 \ replicate (2^(length (tl rs) - 2)) (last rs)" and is2_replicate: "is2 \ replicate (2^(length (tl rs) - 2)) (last rs)" by (metis valid_index_split) then have is1_valid: "is1 \ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0)" (is ?is1) and is2_valid: "is2 \ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0)" (is ?is2) proof - have "last (tl rs) = last rs" by (metis "2" \rs = r # tl rs\ last_ConsR list.size(3) not_numeral_le_zero) then show ?is1 ?is2 using dims_output_witness[of "tl rs"] using dims_output_witness[of "tl rs"] 2 3 is1_replicate is2_replicate \last (tl rs) = last rs\ by auto qed text \A shorthand for the condition to find a "1" in the tensor:\ let ?cond = "\is rs. nths is {n. even n} = nths is {n. odd n} \ (\i\set is. i < last (butlast rs))" text \We can use the IH on our newly created is1 and is2:\ have IH_is12: "Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is1 = (if (?cond is1 (tl rs)) then 1 else 0)" "Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is2 = (if (?cond is2 (tl rs)) then 1 else 0)" using IH is1_valid is2_valid by fast+ text \In the induction step we have to add two layers: first the Pool layer, then the Conv layer. The Pool layer connects the two subtrees. Therefore the two conditions on is1 and is2 become one, and we have to prove that they are equivalent:\ have "?cond is1 (tl rs) \ ?cond is2 (tl rs) \ ?cond is rs" proof - have "length is1 = 2 ^ (length (tl rs) - 2)" "length is2 = 2 ^ (length (tl rs) - 2)" using is1_replicate is2_replicate by (simp_all add: valid_index_length) then have "even (length is1)" "even (length is2)" by (metis Cons.hyps(2) One_nat_def add_gr_0 diff_diff_left even_numeral even_power length_butlast length_tl list.size(4) one_add_one zero_less_Suc)+ then have "{j. j + length is1 \ {n. even n}} = {n. even n}" "{j. j + length is1 \ {n. odd n}} = {n. odd n}" by simp_all have "length (nths is2 (Collect even)) = length (nths is2 (Collect odd))" using length_nths_even \even (length is2)\ by blast have cond1_iff: "(nths is1 (Collect even) = nths is1 {n. odd n} \ nths is2 (Collect even) = nths is2 {n. odd n}) = (nths is (Collect even) = nths is {n. odd n})" unfolding `is = is1 @ is2` nths_append `{j. j + length is1 \ {n. odd n}} = {n. odd n}` `{j. j + length is1 \ {n. even n}} = {n. 
even n}` by (simp add: \length (nths is2 (Collect even)) = length (nths is2 (Collect odd))\) have "last (butlast (tl rs)) = last (butlast rs)" using Nitpick.size_list_simp(2) \even (length is1)\ \length is1 = 2 ^ (length (tl rs) - 2)\ butlast_tl last_tl length_butlast length_tl not_less_eq zero_less_diff by (metis (full_types) Cons.hyps(2) length_Cons less_nat_zero_code) have cond2_iff: "(\i\set is1. i < last (butlast (tl rs))) \ (\i\set is2. i < last (butlast (tl rs))) \ (\i\set is. i < last (butlast rs))" unfolding `last (butlast (tl rs)) = last (butlast rs)` `is = is1 @ is2` set_append by blast then show ?thesis using cond1_iff cond2_iff by blast qed text \Now we can make the Pool layer step: \ have lookup_witness': "Tensor.lookup ((tensors_from_net (witness' (rs ! 1) (tl (tl rs)))) $ 0) is = (if ?cond is rs then 1 else 0)" proof - have lookup_prod: "Tensor.lookup ((tensors_from_net (witness_l (tl rs)) $ 0) \ (tensors_from_net (witness_l (tl rs))) $ 0) is = (if ?cond is rs then 1 else 0)" using `?cond is1 (tl rs) \ ?cond is2 (tl rs) \ ?cond is rs` unfolding `is = is1 @ is2` lookup_tensor_prod[OF is1_valid is2_valid] IH_is12 by auto have witness_l_tl: "witness_l (tl rs) = witness (rs ! 1) (rs ! 2) (tl (tl (tl rs)))" by (metis One_nat_def Suc_1 \rs = r # tl rs\ nth_Cons_Suc) have tl_tl:"(tl (tl rs)) = ((rs ! 2) # tl (tl (tl rs)))" proof - have "length (tl (tl rs)) \ 0" by (metis One_nat_def Suc_eq_plus1 diff_diff_left diff_is_0_eq length_tl not_less_eq_eq Cons.prems(2) numeral_3_eq_3) then have "tl (tl rs) \ []" by fastforce then show ?thesis by (metis list.exhaust_sel nth_Cons_0 nth_Cons_Suc numeral_2_eq_2 tl_Nil) qed have length_gt0:"dim_vec (tensors_from_net (witness (rs ! 1) (rs ! 2) (tl (tl (tl rs))))) > 0" using output_size_correct_tensors[of "witness (rs ! 1) (rs ! 2) (tl (tl (tl rs)))"] witness_is_deep_model[of "rs ! 1" "rs ! 2" "tl (tl (tl rs))"] valid_deep_model[of "rs ! 1" "rs ! 2" "tl (tl (tl rs))"] output_size.simps witness.simps by (metis "2" "3" One_nat_def \rs = r # tl rs\ deep_model.elims length_greater_0_conv list.size(3) not_numeral_le_zero nth_Cons_Suc nth_mem) then have "tensors_from_net (witness' (rs ! 1) ((rs ! 2) # tl (tl (tl rs)))) $ 0 = (tensors_from_net (witness_l (tl rs)) $ 0) \ (tensors_from_net (witness_l (tl rs)) $ 0)" unfolding witness'.simps tensors_from_net.simps witness_l_tl using index_component_mult by blast then show ?thesis using lookup_prod tl_tl by simp qed text \Then we can make the Conv layer step: \ show ?case proof - have "valid_net' (witness' (rs ! 1) (tl (tl rs)))" by (simp add: witness'_valid) have "output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 1" by (metis "2" Nitpick.size_list_simp(2) diff_diff_left diff_is_0_eq hd_Cons_tl deep_model'.simps(2) deep_model.elims length_tl not_less_eq_eq numeral_2_eq_2 numeral_3_eq_3 one_add_one output_size.simps(2) output_size.simps(3) tl_Nil witness'_is_deep_model) have if_resolve:"(if length (tl (tl rs)) = 0 then eye_matrix else if length (tl (tl rs)) = 1 then all1_matrix else copy_first_matrix) = copy_first_matrix" by (metis "2" Cons.prems(2) Nitpick.size_list_simp(2) One_nat_def Suc_n_not_le_n not_numeral_le_zero numeral_3_eq_3) have "tensors_from_net (Conv (copy_first_matrix (rs ! 0) (rs ! 1)) (witness' (rs ! 1) (tl (tl rs)))) $ j = tensors_from_net (witness' (rs ! 1) (tl (tl rs))) $ 0" using tensors_from_net_Conv_copy_first[OF `valid_net' (witness' (rs ! 1) (tl (tl rs)))` `j < rs ! 0`, unfolded `output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 
1`] using "4" One_nat_def \rs = r # tl rs\ nth_Cons_Suc by metis then show ?thesis unfolding witness.simps if_resolve `output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 1` using lookup_witness' \valid_net' (witness' (rs ! 1) (tl (tl rs)))\ hd_conv_nth output_size_correct_tensors by fastforce qed qed lemma witness_matricization: assumes "i < dim_row Aw'" and "j < dim_col Aw'" shows "Aw' $$ (i, j) = (if i=j \ (\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < last (butlast rs)) then 1 else 0)" proof - define "is" where "is = weave {n. even n} (digit_encode (nths (Tensor.dims Aw) {n. even n}) i) (digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)" have lookup_eq: "Aw' $$ (i, j) = Tensor.lookup Aw is" using Aw'_def matricize_def dims_Aw'(1)[symmetric, unfolded A_def] dims_Aw'(2)[symmetric, unfolded A_def Collect_neg_eq] index_mat(1)[OF `i < dim_row Aw'` `j < dim_col Aw'`] is_def Collect_neg_eq case_prod_conv by (metis (no_types) Aw'_def Collect_neg_eq case_prod_conv is_def matricize_def) have "is \ Tensor.dims Aw" using is_def valid_index_weave A_def Collect_neg_eq assms digit_encode_valid_index dims_Aw' by metis have "even (order Aw)" unfolding Aw_def using assms dims_output_witness even_numeral le_eq_less_or_eq numeral_2_eq_2 numeral_3_eq_3 deep no_zeros y_valid by fastforce have nths_dimsAw: "nths (Tensor.dims Aw) (Collect even) = nths (Tensor.dims Aw) {n. odd n}" proof - have 0:"Tensor.dims (tensors_from_net (witness_l rs) $ y) = replicate (2 ^ (length rs - 2)) (last rs)" using dims_output_witness[OF _ no_zeros y_valid] using deep by linarith show ?thesis unfolding A_def using nths_replicate by (metis (no_types, lifting) "0" Aw_def \even (order Aw)\ length_replicate length_nths_even) qed have "i = j \ nths is (Collect even) = nths is {n. odd n}" proof have eq_lengths: "length (digit_encode (nths (Tensor.dims Aw) (Collect even)) i) = length (digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)" unfolding length_digit_encode by (metis \even (order Aw)\ length_nths_even) then show "i = j \ nths is (Collect even) = nths is {n. odd n}" unfolding is_def using nths_weave[of "digit_encode (nths (Tensor.dims Aw) (Collect even)) i" "Collect even" "digit_encode (nths (Tensor.dims Aw) {n. odd n}) j", unfolded eq_lengths, unfolded Collect_neg_eq[symmetric] card_even mult_2[symmetric] card_odd] nths_dimsAw by simp show "nths is (Collect even) = nths is {n. odd n} \ i = j" unfolding is_def using nths_weave[of "digit_encode (nths (Tensor.dims Aw) (Collect even)) i" "Collect even" "digit_encode (nths (Tensor.dims Aw) {n. odd n}) j", unfolded eq_lengths, unfolded Collect_neg_eq[symmetric] card_even mult_2[symmetric] card_odd] using \nths (Tensor.dims Aw) (Collect even) = nths (Tensor.dims Aw) {n. odd n}\ deep no_zeros y_valid assms digit_decode_encode dims_Aw' by auto (metis digit_decode_encode_lt) qed have "i=j \ set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i) = set is" unfolding is_def nths_dimsAw using set_weave[of "(digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)" "Collect even" "(digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)", unfolded mult_2[symmetric] card_even Collect_neg_eq[symmetric] card_odd] Un_absorb card_even card_odd mult_2 by blast then show ?thesis unfolding lookup_eq using witness_tensor[OF `is \ Tensor.dims Aw`] by (simp add: A_def \(i = j) = (nths is (Collect even) = nths is {n. odd n})\) qed definition "rows_with_1 = {i. (\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < last (butlast rs))}" lemma card_low_digits: assumes "m>0" "\d. 
d\set ds \ m \ d" shows "card {i. i (\i0\set (digit_encode ds i). i0 < m)} = m ^ (length ds)" using assms proof (induction ds) case Nil then show ?case using prod_list.Nil by simp next case (Cons d ds) define low_digits where "low_digits ds i \ i < prod_list ds \ (\i0\set (digit_encode ds i). i0 < m)" for ds i have "card {i. low_digits ds i} = m ^ (length ds)" unfolding low_digits_def by (simp add: Cons.IH Cons.prems(1) Cons.prems(2)) have "card {i. low_digits (d # ds) i} = card ({.. {i. low_digits ds i})" proof - define f where "f p = fst p + d * snd p" for p have "inj_on f ({.. {i. low_digits ds i})" proof (rule inj_onI) fix x y assume "x \ {.. {i. low_digits ds i}" "y \ {.. {i. low_digits ds i}" "f x = f y" then have "fst xf x = f y\ \f x mod d = fst x\ \fst y < d\ f_def by auto show "x = y" using \f x = f y\ \f x div d = snd x\ \f x mod d = fst x\ \f y div d = snd y\ \f y mod d = fst y\ prod_eqI by fastforce qed have "f ` ({.. {i. low_digits ds i}) = {i. low_digits (d # ds) i}" proof (rule subset_antisym; rule subsetI) fix x assume "x \ f ` ({.. {i. low_digits ds i})" then obtain i0 i1 where "x = i0 + d * i1" "i0 < m" "low_digits ds i1" using f_def by force then have "i0 {i. low_digits (d # ds) i}" unfolding low_digits_def proof (rule; rule conjI) have "i1 < prod_list ds" "\i0\set (digit_encode ds i1). i0 < m" using `low_digits ds i1` low_digits_def by auto show "x < prod_list (d # ds)" unfolding prod_list.Cons `x = i0 + d * i1` using `i0 0" by (metis \i0 < d\ gr_implies_not0) then have "(i0 + d * i1) div (d * prod_list ds) = 0" by (simp add: Divides.div_mult2_eq \i0 < d\ \i1 < prod_list ds\) then show "i0 + d * i1 < d * prod_list ds" by (metis (no_types) \i0 < d\ \i1 < prod_list ds\ div_eq_0_iff gr_implies_not0 no_zero_divisors) qed show "\i0\set (digit_encode (d # ds) x). i0 < m" using \\i0\set (digit_encode ds i1). i0 < m\ \i0 < d\ \i0 < m\ \x = i0 + d * i1\ by auto qed next fix x assume "x \ {i. low_digits (d # ds) i}" then have "x < prod_list (d # ds)" "\i0\set (digit_encode (d # ds) x). i0 < m" using low_digits_def by auto have "x mod d < m" using `\i0\set (digit_encode (d # ds) x). i0 < m`[unfolded digit_encode.simps] by simp have "x div d < prod_list ds" using `x < prod_list (d # ds)`[unfolded prod_list.Cons] by (metis div_eq_0_iff div_mult2_eq mult_0_right not_less0) have "\i0\set (digit_encode ds (x div d)). i0 < m" by (simp add: \\i0\set (digit_encode (d # ds) x). i0 < m\) have "f ((x mod d),(x div d)) = x" by (simp add: f_def) show "x \ f ` ({.. {i. low_digits ds i})" by (metis SigmaI \\i0\set (digit_encode ds (x div d)). i0 < m\ \f (x mod d, x div d) = x\ \x div d < prod_list ds\ \x mod d < m\ image_eqI lessThan_iff low_digits_def mem_Collect_eq) qed then have "bij_betw f ({.. {i. low_digits ds i}) {i. low_digits (d # ds) i}" by (simp add: \inj_on f ({.. {i. low_digits ds i})\ bij_betw_def) then show ?thesis by (simp add: bij_betw_same_card) qed then show ?case unfolding `card {i. low_digits ds i} = m ^ (length ds)` card_cartesian_product using low_digits_def by simp qed lemma card_rows_with_1: "card {i\rows_with_1. irows_with_1. i (\i0\set (digit_encode (nths (Tensor.dims Aw) (Collect even)) i). i0 < r)}" (is "?A = ?B") proof (rule subset_antisym; rule subsetI) fix i assume "i \ ?A" then have "i < dim_row Aw'" "\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). 
i0 < last (butlast rs)" using rows_with_1_def by auto then have "i < prod_list (nths (dims Aw) (Collect even))" using dims_Aw' by linarith then have "digit_encode (nths (dims Aw) (Collect even)) i \ nths (dims Aw) (Collect even)" using digit_encode_valid_index by auto have "\i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < r" proof fix i0 assume 1:"i0 \ set (digit_encode (nths (dims Aw) (Collect even)) i)" then obtain k where "k < length (digit_encode (nths (dims Aw) (Collect even)) i)" "digit_encode (nths (dims Aw) (Collect even)) i ! k = i0" by (meson in_set_conv_nth) have "i0 < last (butlast rs)" using \\i0\set (digit_encode (nths (dims Aw) (Collect even)) i). i0 < last (butlast rs)\ 1 by blast have "set (nths (dims Aw) (Collect even)) \ {last rs}" unfolding dims_Aw using subset_eq by fastforce then have "nths (dims Aw) (Collect even) ! k = last rs" using \digit_encode (nths (dims Aw) (Collect even)) i \ nths (dims Aw) (Collect even)\ \k < length (digit_encode (nths (dims Aw) (Collect even)) i)\ nth_mem valid_index_length by auto then have "i0 < last rs" using valid_index_lt \digit_encode (nths (dims Aw) (Collect even)) i ! k = i0\ \digit_encode (nths (dims Aw) (Collect even)) i \ nths (dims Aw) (Collect even)\ \k < length (digit_encode (nths (dims Aw) (Collect even)) i)\ valid_index_length by fastforce then show "i0 < r" unfolding r_def by (simp add: \i0 < last (butlast rs)\) qed then show "i \ ?B" using \i < prod_list (nths (dims Aw) (Collect even))\ by blast next fix i assume "i\?B" then show "i\?A" by (simp add: dims_Aw' r_def rows_with_1_def) qed have 2:"\d. d \ set (nths (Tensor.dims Aw) (Collect even)) \ r \ d" proof - fix d assume "d \ set (nths (Tensor.dims Aw) (Collect even))" then have "d \ set (Tensor.dims Aw)" using in_set_nthsD by fast then have "d = last rs" using dims_Aw by simp then show "r \ d" by (simp add: r_def) qed have 3:"0 < r" unfolding r_def by (metis deep diff_diff_cancel diff_zero dual_order.trans in_set_butlastD last_in_set length_butlast list.size(3) min_def nat_le_linear no_zeros not_numeral_le_zero numeral_le_one_iff rel_simps(3)) have 4: "length (nths (Tensor.dims Aw) (Collect even)) = N_half" unfolding length_nths order_Aw using card_even[of N_half] by (metis (mono_tags, lifting) Collect_cong) then show ?thesis using card_low_digits[of "r" "nths (Tensor.dims Aw) (Collect even)"] 1 2 3 4 by metis qed lemma infinite_rows_with_1: "infinite rows_with_1" proof - define listpr where "listpr = prod_list (nths (Tensor.dims Aw) {n. even n})" have "\i. listpr dvd i \ i \ rows_with_1" proof - fix i assume dvd_i: "listpr dvd i" { fix i0::nat assume "i0\set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i)" then have "i0=0" using digit_encode_0 dvd_i listpr_def by auto then have "i0 < last (butlast rs)" using deep no_zeros by (metis Nitpick.size_list_simp(2) One_nat_def Suc_le_lessD in_set_butlastD last_in_set length_butlast length_tl not_numeral_less_zero numeral_2_eq_2 numeral_3_eq_3 numeral_le_one_iff semiring_norm(70)) } then show "i\rows_with_1" by (simp add: rows_with_1_def) qed have 0:"Tensor.dims Aw = replicate (2 ^ (length rs - 2)) (last rs)" unfolding Aw_def using dims_output_witness[OF _ no_zeros y_valid] using deep by linarith then have "listpr > 0" unfolding listpr_def 0 by (metis "0" deep last_in_set length_greater_0_conv less_le_trans no_zeros dims_Aw'_pow(1) dims_Aw'(1) zero_less_numeral zero_less_power) then have "inj (( * ) listpr)" by (metis injI mult_left_cancel neq0_conv) then show ?thesis using `\i. 
listpr dvd i \<Longrightarrow> i \<in> rows_with_1`
    by (meson dvd_triv_left image_subset_iff infinite_iff_countable_subset)
qed

lemma witness_submatrix: "submatrix Aw' rows_with_1 rows_with_1 = 1\<^sub>m (r^N_half)"
proof
  show "dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1\<^sub>m (r ^ N_half))"
    unfolding index_one_mat(2) dim_submatrix(1)
    by (metis (full_types) set_le_in card_rows_with_1)
  show "dim_col (submatrix Aw' rows_with_1 rows_with_1) = dim_col (1\<^sub>m (r ^ N_half))"
    by (metis \<open>dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1\<^sub>m (r ^ N_half))\<close> dim_submatrix(1) dim_submatrix(2) index_one_mat(2) index_one_mat(3) dims_Aw'_pow)
  show "\<And>i j. i < dim_row (1\<^sub>m (r ^ N_half)) \<Longrightarrow> j < dim_col (1\<^sub>m (r ^ N_half)) \<Longrightarrow>
    submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = 1\<^sub>m (r ^ N_half) $$ (i, j)"
  proof -
    fix i j assume "i < dim_row (1\<^sub>m (r ^ N_half))" "j < dim_col (1\<^sub>m (r ^ N_half))"
    then have "i < r ^ N_half" "j < r ^ N_half" by auto
    then have "i < card {i \<in> rows_with_1. i < dim_row Aw'}" "j < card {i \<in> rows_with_1. i < dim_col Aw'}"
      using card_rows_with_1 dims_Aw'_pow by auto
    then have "pick rows_with_1 i < dim_row Aw'" "pick rows_with_1 j < dim_col Aw'"
      using card_le_pick_inf[OF infinite_rows_with_1, of "dim_row Aw'" i]
      using card_le_pick_inf[OF infinite_rows_with_1, of "dim_col Aw'" j] by force+
    have "\<forall>i0\<in>set (digit_encode (nths (dims Aw) (Collect even)) (pick rows_with_1 i)). i0 < last (butlast rs)"
      using infinite_rows_with_1 pick_in_set_inf rows_with_1_def by auto
    then have "Aw' $$ (pick rows_with_1 i, pick rows_with_1 j) = (if pick rows_with_1 i = pick rows_with_1 j then 1 else 0)"
      using witness_matricization[OF `pick rows_with_1 i < dim_row Aw'` `pick rows_with_1 j < dim_col Aw'`] by simp
    then have "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = (if pick rows_with_1 i = pick rows_with_1 j then 1 else 0)"
      using submatrix_index
      by (metis (no_types, lifting) \<open>dim_col (submatrix Aw' rows_with_1 rows_with_1) = dim_col (1\<^sub>m (r ^ N_half))\<close> \<open>dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1\<^sub>m (r ^ N_half))\<close> \<open>i < dim_row (1\<^sub>m (r ^ N_half))\<close> \<open>j < r ^ N_half\<close> dim_submatrix(1) dim_submatrix(2) index_one_mat(3))
    then have "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = (if i = j then 1 else 0)"
      using pick_eq_iff_inf[OF infinite_rows_with_1] by auto
    then show "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = 1\<^sub>m (r ^ N_half) $$ (i, j)"
      by (simp add: \<open>i < r ^ N_half\<close> \<open>j < r ^ N_half\<close>)
  qed
qed

lemma witness_det: "det (submatrix Aw' rows_with_1 rows_with_1) \<noteq> 0"
  unfolding witness_submatrix by simp

end

(* Examples to show that the locales can be instantiated: *)
interpretation example : deep_model_correct_params False "[10,10,10]"
  unfolding deep_model_correct_params_def by simp
interpretation example : deep_model_correct_params_y False "[10,10,10]" 1
  unfolding deep_model_correct_params_y_def deep_model_correct_params_y_axioms_def
    deep_model_correct_params_def by simp

end
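(* A worked instance of the development above, as a sketch; the lemma and its
   "by simp" proof are assumptions and may need adjustment before they check.
   In the interpretation with rs = [10,10,10], the witness network unfolds per
   witness.simps/witness'.simps (cf. lemmas witness_l0' and witness_l1) to

lemma "witness (10::nat) 10 [10] = Conv (all1_matrix 10 10)
    (Pool (Conv (eye_matrix 10 10) (Input 10)) (Conv (eye_matrix 10 10) (Input 10)))"
  by simp

   and witness_det then exhibits a weight assignment whose matricized tensor
   Aw' has a nonsingular 10x10 submatrix, i.e. rank at least r^N_half = 10. *)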