diff --git a/thys/Number_Theoretic_Transform/Butterfly.thy b/thys/Number_Theoretic_Transform/Butterfly.thy new file mode 100644 --- /dev/null +++ b/thys/Number_Theoretic_Transform/Butterfly.thy @@ -0,0 +1,1323 @@ +(* +Title: Butterfly Algorithm for Number Theoretic Transform +Author: Thomas Ammer +*) + +theory Butterfly + imports NTT "HOL-Library.Discrete" +begin + +text \\pagebreak\ + +section \Butterfly Algorithms\ +text \\label{Butterfly}\ + +text \\indent Several recursive algorithms for $FFT$ based on +the divide and conquer principle have been developed in order to speed up the transform. +A method for reducing complexity is the butterfly scheme. +In this formalization, we consider the butterfly algorithm by Cooley +and Tukey~\parencite{Good1997} adapted to the setting of \textit{NTT}. +\ + +text \\noindent We additionally assume that $n$ is power of two.\ + +locale butterfly = ntt + + fixes N + assumes n_two_pot: "n = 2^N" +begin + +subsection \Recursive Definition\ + +text \Let's recall the definition of a transformed vector element: +\begin{equation*} +\mathsf{NTT}(\vec{x})_i = \sum _{j = 0} ^{n-1} x_j \cdot \omega ^{i\cdot j} +\end{equation*} + +We assume $n = 2^N$ and obtain: + +\begin{align*} +\sum _{j = 0} ^{< 2^N} x_j \cdot \omega ^{i\cdot j} \\ &= +\sum _{j = 0} ^{< 2^{N-1}} x_{2j} \cdot \omega ^{i\cdot 2j} + + \sum _{j = 0} ^{< 2^{N-1}} x_{2j+1} \cdot \omega ^{i\cdot (2j+1)} \\ +& = \sum _{j = 0} ^{< 2^{N-1}} x_{2j} \cdot (\omega^2) ^{i\cdot j} + + \omega^i \cdot \sum _{j = 0} ^{< 2^{N-1}} x_{2j+1} \cdot (\omega^2) ^{i\cdot j}\\ +& = (\sum _{j = 0} ^{< 2^{N-2}} x_{4j} \cdot (\omega^4) ^{i\cdot j} + + \omega^i \cdot \sum _{j = 0} ^{< 2^{N-2}} x_{4j+2} \cdot (\omega^4) ^{i\cdot j}) \\ +& \hspace{1cm}+ \omega^i \cdot (\sum _{j = 0} ^{< 2^{N-2}} x_{4j+1} \cdot (\omega^4) ^{i\cdot j} + + \omega^i \cdot \sum _{j = 0} ^{< 2^{N-2}} x_{4j+3} \cdot (\omega^4) ^{i\cdot j}) \text{ etc.} +\end{align*} + +which gives us a recursive algorithm: + +\begin{itemize} +\item Compose vectors consisting of elements at even and odd indices respectively +\item Compute a transformation of these vectors recursively where the dimensions are halved. +\item Add results after scaling the second subresult by $\omega^i$ +\end{itemize} + +\ + +text \Now we give a functional definition of the analogue to $FFT$ adapted to finite fields. +A gentle introduction to $FFT$ can be found in~\parencite{10.5555/1614191}. +For the fast implementation of Number Theoretic Transform in particular, have a look at~\parencite{cryptoeprint:2016/504}.\ + +text \(The following lemma is needed to obtain an automated termination proof of $FNTT$.)\ +lemma FNTT_termination_aux [simp]: "length (filter P [0..Please note that we closely adhere to the textbook definition which just +talks about elements at even and odd indices. We model the informal definition by predefined functions, +since this seems to be more handy during proofs. +An algorithm splitting the elements smartly will be presented afterwards.\ + +fun FNTT::"('a mod_ring) list \ ('a mod_ring) list" where +"FNTT [] = []"| +"FNTT [a] = [a]"| +"FNTT nums = (let nn = length nums; + nums1 = [nums!i. i \ filter even [0.. filter odd [0.. x k. x*(\^( (n div nn) * k))) fntt2 [0..<(nn div 2)]); + sum2 = map2 (-) fntt1 (map2 ( \ x k. x*(\^( (n div nn) * k))) fntt2 [0..<(nn div 2)]) + in sum1@sum2)" + +lemmas [simp del] = FNTT_termination_aux + + +text \ +Finally, we want to prove correctness, i.e. $FNTT\; xs = NTT\;xs$. +Since we consider a recursive algorithm, some kind of induction is appropriate: +Assume the claim for $\frac{2^d}{2} = 2^{d-1}$ and prove it for $2^d$, where $2^d$ is the vector length. +This implies that we have to talk about \textit{NTT}s with respect to some powers of $\omega$. +In particular, we decide to annotate \textit{NTT} with a degree $degr$ +indicating the referred vector length. There is a correspondence to the current level $l$ of recursion: + +\begin{equation*} +degr = 2^{N-l} +\end{equation*} + +\ + +text \\noindent A generalized version of \textit{NTT} keeps track of all levels during recursion:\ + +definition "ntt_gen numbers degr i = (\j=0..<(length numbers). (numbers ! j) * \^((n div degr)*i*j)) " + +definition "NTT_gen degr numbers = map (ntt_gen numbers (degr)) [0..< length numbers]" + +text \Whenever generalized \textit{NTT} is applied to a list of full length, + then its actually equal to the defined \textit{NTT}.\ + +lemma NTT_gen_NTT_full_length: + assumes "length numbers =n" + shows "NTT_gen n numbers = NTT numbers" + unfolding NTT_gen_def ntt_gen_def NTT_def ntt_def + using assms by simp + +subsection \Arguments on Correctness\ +text \First some general lemmas on list operations.\ + +lemma length_even_filter: "length [f i . i <- (filter even [0.. i < length ys \ (map2 f xs ys) ! i = f (xs ! i) (ys ! i)" + by (induction xs arbitrary: ys i) auto + +lemma filter_last_not: "\ P x \ filter P (xs@[x]) = filter P xs" + by simp + +lemma filter_even_map: "filter even [0..<2*(x::nat)] = map ((*) (2::nat)) [0.. 2*x = l \ (filter even [0.. y. (2::nat)*y +1) [0.. 2*x = l \ (filter odd [0..\noindent Lemmas by using the assumption $n = 2^N$.\ + +text \\noindent ($-1$ denotes the additive inverse of $1$ in the finite field.)\ + +lemma n_min1_2: "n = 2 \ \ = -1" + using omega_properties(1) omega_properties(2) power2_eq_1_iff by blast + +lemma n_min1_gr2: + assumes "n > 2" + shows "\^(n div 2) = -1" +proof- + have "\^(n div 2) \ -1 \ False" + proof- + assume "\^(n div 2) \ -1" + hence False + proof(cases "\^(n div 2) = 1") +case True + then show ?thesis using omega_properties(3) + by (metis Euclidean_Division.div_eq_0_iff div_less_dividend leD less_le_trans n_lst2 one_less_numeral_iff pos2 semiring_norm(76) zero_neq_numeral) +next + case False + hence "(\^(n div 2)) ^ (2::nat) \ 1" + by (smt (verit, ccfv_threshold) n_two_pot One_nat_def \\ ^ (n div 2) \ - 1\ diff_zero leD n_lst2 not_less_eq omega_properties(1) one_less_numeral_iff one_power2 power2_eq_square power_mult power_one_right power_strict_increasing_iff semiring_norm(76) square_eq_iff two_powr_div two_powrs_div) + moreover have "(n div 2) * 2 = n" using n_two_pot n_lst2 + by (metis One_nat_def Suc_lessD assms div_by_Suc_0 one_less_numeral_iff power_0 power_one_right power_strict_increasing_iff semiring_norm(76) two_powrs_div) + ultimately show ?thesis using omega_properties(1) + by (metis power_mult) +qed + thus False by simp +qed + then show ?thesis by auto +qed + +lemma div_exp_sub: "2^l < n \ n div (2^l) = 2^(N-l)"using n_two_pot + by (smt (z3) One_nat_def diff_is_0_eq diff_le_diff_pow div_if div_le_dividend eq_imp_le le_0_eq le_Suc_eq n_lst2 nat_less_le not_less_eq_eq numeral_2_eq_2 power_0 two_powr_div) + +lemma omega_div_exp_min1: + assumes "2^(Suc l) \ n" + shows "(\ ^(n div 2^(Suc l)))^(2^l) = -1" +proof- + have "(\ ^(n div 2^(Suc l)))^(2^l) = \ ^((n div 2^(Suc l))*2^l)" + by (simp add: power_mult) + moreover have "(n div 2^(Suc l)) = 2^(N - Suc l)" using assms div_exp_sub + by (metis n_two_pot eq_imp_le le_neq_implies_less one_less_numeral_iff power_diff power_inject_exp semiring_norm(76) zero_neq_numeral) + moreover have "N \ Suc l" using assms n_two_pot + by (metis diff_is_0_eq diff_le_diff_pow gr0I leD le_refl) + moreover hence "(2::nat)^(N - Suc l)*2^l = 2^(N- 1)" + by (metis Nat.add_diff_assoc diff_Suc_1 diff_diff_cancel diff_le_self le_add1 le_add_diff_inverse plus_1_eq_Suc power_add) + ultimately show ?thesis + by (metis n_two_pot One_nat_def \n div 2 ^ Suc l = 2 ^ (N - Suc l)\ diff_Suc_1 div_exp_sub n_lst2 n_min1_2 n_min1_gr2 nat_less_le nat_power_eq_Suc_0_iff one_less_numeral_iff power_inject_exp power_one_right semiring_norm(76)) +qed + +lemma omg_n_2_min1: "\^(n div 2) = -1" + by (metis n_lst2 n_min1_2 n_min1_gr2 nat_less_le numeral_Bit0_div_2 numerals(1) power_one_right) + +lemma neg_cong: "-(x::('a mod_ring)) = - y \ x = y" by simp + +text \Generalized \textit{NTT} indeed describes all recursive levels, +and thus, it is actually equivalent to the ordinary \textit{NTT} definition.\ + +theorem FNTT_NTT_gen_eq: "length numbers = 2^l \ 2^l \ n \ FNTT numbers = NTT_gen (length numbers) numbers" +proof(induction l arbitrary: numbers) + case 0 + then show ?case unfolding NTT_gen_def ntt_gen_def + by (auto simp: length_Suc_conv) +next + case (Suc l) + text \We define some lists that are used during the recursive call.\ + define numbers1 where "numbers1 = [numbers!i . i <- (filter even [0.. x k. x*(\^( (n div (length numbers)) * k))) + fntt2 [0..<((length numbers) div 2)])" + define sum2 where + "sum2 = map2 (-) fntt1 (map2 ( \ x k. x*(\^( (n div (length numbers)) * k))) + fntt2 [0..<((length numbers) div 2)])" + define l1 where "l1 = length numbers1" + define l2 where "l2 = length numbers2" + define llen where "llen = length numbers" + + text \Properties of those lists.\ + have numbers1_even: "length numbers1 = 2^l" + using numbers1_def length_even_filter Suc by simp + have numbers2_even: "length numbers2 = 2^l" + using numbers2_def length_odd_filter Suc by simp + have numbers1_fntt: "fntt1 = NTT_gen (2^l) numbers1" + using fntt1_def Suc.IH[of numbers1] numbers1_even Suc(3) by simp + hence fntt1_by_index: "fntt1 ! i = ntt_gen numbers1 (2^l) i" if "i < 2^l" for i + unfolding NTT_gen_def by (simp add: numbers1_even that) + have numbers2_fntt: "fntt2 = NTT_gen (2^l) numbers2" + using fntt2_def Suc.IH[of numbers2] numbers2_even Suc(3) by simp + hence fntt2_by_index: "fntt2 ! i = ntt_gen numbers2 (2^l) i" if "i < 2^l" for i + unfolding NTT_gen_def + by (simp add: numbers2_even that) + have fntt1_length: "length fntt1 = 2^l" unfolding numbers1_fntt NTT_gen_def numbers1_def + using numbers1_def numbers1_even by force + have fntt2_length: "length fntt2 = 2^l" unfolding numbers2_fntt NTT_gen_def numbers2_def + using numbers2_def numbers2_even by force + + text \We show that the list resulting from $FNTT$ is equal to the $NTT$ list. + First, we prove $FNTT$ and $NTT$ to be equal concerning their first halves.\ + have before_half: "map (ntt_gen numbers llen) [0..<(llen div 2)] = sum1" + proof- + + text \Length is important, since we want to use list lemmas later on.\ + have 00:"length (map (ntt_gen numbers llen) [0..<(llen div 2)]) = length sum1" + unfolding sum1_def llen_def + using Suc(2) map2_length[of _ fntt2 "[0..x y. x * \ ^ (n div length numbers * y)) fntt2 [0..We show equality by extensionality w.r.t. indices.\ + have 02:"(map (ntt_gen numbers llen) [0..<(llen div 2)]) ! i = sum1 ! i" + if "i < 2^l" for i + proof- + text \First simplify this term.\ + have 000:"(map (ntt_gen numbers llen) [0..<(llen div 2)]) ! i = + ntt_gen numbers llen i" + using "00" "01" that by auto + + text \Expand the definition of $sum1$ and massage the result.\ + moreover have 001:"sum1 ! i = (fntt1!i) + (fntt2!i) * (\^((n div llen) * i))" + unfolding sum1_def using map2_index + "00" "01" NTT_gen_def add.left_neutral diff_zero fntt1_length length_map length_upt map2_map_map map_nth nth_upt numbers2_even numbers2_fntt that llen_def by force + moreover have 002:"(fntt1!i) = (\j=0..^((n div (2^l))*i*j))" + unfolding l1_def + using fntt1_by_index[of i] that unfolding ntt_gen_def by simp + have 003:"... = (\j=0..^((n div llen)*i*(2*j)))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers1_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l1 " + hence "map ((!) numbers) (filter even [0.. ^ (n div 2 ^ l * i * j) = + numbers ! (2 * j) * \ ^ (n div llen * i * (2 * j))" + unfolding llen_def l1_def l2_def by (metis (mono_tags, lifting) mult.assoc mult.left_commute) + qed + done + moreover have 004: + "(fntt2!i) * (\^((n div llen) * i)) = + (\j=0..^((n div (2^l))*i*j+ (n div llen) * i))" + apply(rule trans[where s = "(\j = 0.. ^ (n div 2 ^ l * i * j) * \ ^ (n div llen * i))"]) + subgoal + unfolding l2_def llen_def + using fntt2_by_index[of i] that sum_in[of _ "(\^((n div llen) * i))" "l2"] comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding ntt_gen_def + using sum_rules apply presburger + done + apply (rule sum_rules(2)) + subgoal for j + using fntt2_by_index[of i] that sum_in[of _ "(\^((n div llen) * i))" "l2"] comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding ntt_gen_def + apply auto + done + done + have 005: "\ = (\j=0..^((n div llen)*i*(2*j+1))))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers2_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l2 " + hence "map ((!) numbers) (filter odd [0.. ^ (n div 2 ^ l * i * j + n div llen * i) + = numbers ! (2 * j + 1) * \ ^ (n div llen * i * (2 * j + 1))" unfolding llen_def + by (smt (z3) Groups.mult_ac(2) distrib_left mult.right_neutral mult_2 mult_cancel_left) + qed + done + then show ?thesis + using 000 001 002 003 004 005 + unfolding sum1_def llen_def l1_def l2_def + using sum_splice_other_way_round[of "\ d. numbers ! d * \ ^ (n div length numbers * i * d)" "2^l"] Suc(2) + unfolding ntt_gen_def + by (smt (z3) Groups.mult_ac(2) numbers1_even numbers2_even power_Suc2) + qed + then show ?thesis + by (metis "00" "01" nth_equalityI) + qed + + text \We show equality for the indices in the second halves.\ + have after_half: "map (ntt_gen numbers llen) [(llen div 2)..Equality for every index.\ + have 02:"(map (ntt_gen numbers llen) [(llen div 2)..x y. x * \ ^ (n div llen * y)) fntt2 [0.. ^ (n div llen * i)" + using Suc(2) that by (simp add: fntt2_length llen_def) + have 003: "- fntt2 ! i * \ ^ (n div llen * i) = + fntt2 ! i * \ ^ (n div llen * (i+ llen div 2))" + using Suc(2) omega_div_exp_min1[of l] unfolding llen_def + by (smt (z3) Suc.prems(2) mult.commute mult.left_commute mult_1s_ring_1(2) neq0_conv nonzero_mult_div_cancel_left numeral_One pos2 power_Suc power_add power_mult) + hence 004:"sum2 ! i = (fntt1!i) - (fntt2!i) * (\^((n div llen) * i))" + unfolding sum2_def llen_def + by (simp add: Suc.prems(1) fntt1_length fntt2_length that) + have 005:"(fntt1!i) = + (\j=0..^((n div (2^l))*i*j))" + using fntt1_by_index that unfolding ntt_gen_def l1_def by simp + have 006:"\ =(\j=0..^((n div llen)*i*(2*j)))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers1_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l1 " + hence "map ((!) numbers) (filter even [0.. ^ (n div 2 ^ l * i * j) = + numbers ! (2 * j) * \ ^ (n div llen * i * (2 * j))" + by (metis (mono_tags, lifting) mult.assoc mult.left_commute) + qed + done + have 007:"\ = (\j=0..^((n div llen)*(2^l + i)*(2*j))) " + apply (rule sum_rules(2)) + subgoal for j + using Suc(2) Suc(3) omega_div_exp_min1[of l] llen_def l1_def numbers1_def + apply(smt (verit, del_insts) add.commute minus_power_mult_self mult_2 mult_minus1_right power_add power_mult) + done + done + moreover have 008: "(fntt2!i) * (\^((n div llen) * i)) = + (\j=0..^((n div (2^l))*i*j+ (n div llen) * i))" + apply(rule trans[where s = "(\j = 0.. ^ (n div 2 ^ l * i * j) * \ ^ (n div llen * i))"]) + subgoal + using fntt2_by_index[of i] that sum_in comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding ntt_gen_def + using sum_rules l2_def apply presburger + done + apply (rule sum_rules(2)) + subgoal for j + using fntt2_by_index[of i] that sum_in comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding ntt_gen_def + apply auto + done + done + have 009: "\ = (\j=0..^((n div llen)*i*(2*j+1))))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers2_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l2 " + hence "map ((!) numbers) (filter odd [0.. ^ (n div 2 ^ l * i * j + n div llen * i) + = numbers ! (2 * j + 1) * \ ^ (n div llen * i * (2 * j + 1))" + by (smt (z3) Groups.mult_ac(2) distrib_left mult.right_neutral mult_2 mult_cancel_left) + qed + done + have 010: " (fntt2!i) * (\^((n div llen) * i)) = (\j=0..^((n div llen)*i*(2*j+1)))) " + using 008 009 by presburger + have 011: " - (fntt2!i) * (\^((n div llen) * i)) = + (\j=0..^((n div llen)*i*(2*j+1)))) " + apply(rule neg_cong) + apply(rule trans[of _ "fntt2 ! i * \ ^ (n div llen * i)"]) + subgoal by simp + apply(rule trans[where s="(\j=0..^((n div llen)*i*(2*j+1))))"]) + subgoal using 008 009 by simp + apply(rule sym) + using sum_neg_in[of _ "l2"] + apply simp + done + have 012: "\ = (\j=0..^((n div llen)*(2^l+i)*(2*j+1))))" + apply(rule sum_rules(2)) + subgoal for j + using Suc(2) Suc(3) omega_div_exp_min1[of l] llen_def l2_def + apply (smt (z3) add.commute exp_rule mult.assoc mult_minus1_right plus_1_eq_Suc power_add power_minus1_odd power_mult) + done + done + have 013:"fntt1 ! i = (\j = 0..<2 ^ l. numbers!(2*j) * \ ^ (n div llen * (2^l + i) * (2*j)))" + using 005 006 007 numbers1_even llen_def l1_def by auto + have 014: "(\j = 0..<2 ^ l. numbers ! (2*j + 1) * \ ^ (n div llen* (2^l + i) * (2*j + 1))) = + - fntt2 ! i * \ ^ (n div llen * i)" + using trans[OF l2_def numbers2_even] sym[OF 012] sym[OF 011] by simp + have "ntt_gen numbers llen (2 ^ l + i) = (fntt1!i) - (fntt2!i) * (\^((n div llen) * i))" + unfolding ntt_gen_def apply(subst Suc(2)) + using sum_splice[of "\ d. numbers ! d * \ ^ (n div llen * (2^l+i) * d)" "2^l"] sym[OF 013] 014 Suc(2) by simp + thus ?thesis using 000 sym[OF 001] "004" sum2_def by simp + qed + then show ?thesis + by (metis "00" "01" list_eq_iff_nth_eq) + qed + obtain x y xs where xyxs: "numbers = x#y#xs" using Suc(2) + by (metis FNTT.cases add.left_neutral even_Suc even_add length_Cons list.size(3) mult_2 power_Suc power_eq_0_iff zero_neq_numeral) + show ?case + apply(subst xyxs) + apply(subst FNTT.simps(3)) + apply(subst xyxs[symmetric])+ + unfolding Let_def + using map_append[of "ntt_gen numbers llen" " [0..\noindent \textbf{Major Correctness Theorem for Butterfly Algorithm}.\\ + +We have already shown: +\begin{itemize} +\item Generalized $NTT$ with degree annotation $2^N$ equals usual $NTT$. +\item Generalized $NTT$ tracks all levels of recursion in $FNTT$. +\end{itemize} +Thus, $FNTT$ equals $NTT$. +\ + +theorem FNTT_correct: + assumes "length numbers = n" + shows "FNTT numbers = NTT numbers" + using FNTT_NTT_gen_eq NTT_gen_NTT_full_length assms n_two_pot by force + +subsection \Inverse Transform in Butterfly Scheme\ + +text \We also formalized the inverse transform by using the butterfly scheme. +Proofs are obtained by adaption of arguments for $FNTT$.\ + + +lemmas [simp] = FNTT_termination_aux + +fun IFNTT where +"IFNTT [] = []"| +"IFNTT [a] = [a]"| +"IFNTT nums = (let nn = length nums; + nums1 = [nums!i . i <- (filter even [0.. x k. x*(\^( (n div nn) * k))) ifntt2 [0..<(nn div 2)]); + sum2 = map2 (-) ifntt1 (map2 ( \ x k. x*(\^( (n div nn) * k))) ifntt2 [0..<(nn div 2)]) + in sum1@sum2)" + +lemmas [simp del] = FNTT_termination_aux + + +definition "intt_gen numbers degr i = (\j=0..<(length numbers). (numbers ! j) * \ ^((n div degr)*i*j)) " + +definition "INTT_gen degr numbers = map (intt_gen numbers (degr)) [0..< length numbers]" + +lemma INTT_gen_INTT_full_length: + assumes "length numbers =n" + shows "INTT_gen n numbers = INTT numbers" + unfolding INTT_gen_def intt_gen_def INTT_def intt_def + using assms by simp + +lemma my_div_exp_min1: + assumes "2^(Suc l) \ n" + shows "(\ ^(n div 2^(Suc l)))^(2^l) = -1" + by (metis assms divide_minus1 mult_zero_right mu_properties(1) nonzero_mult_div_cancel_right omega_div_exp_min1 power_one_over zero_neq_one) + +lemma my_n_2_min1: "\^(n div 2) = -1" + by (metis divide_minus1 mult_zero_right mu_properties(1) nonzero_mult_div_cancel_right omg_n_2_min1 power_one_over zero_neq_one) + +text \Correctness proof by common induction technique. Same strategies as for $FNTT$.\ + +theorem IFNTT_INTT_gen_eq: + "length numbers = 2^l \ 2^l \ n \ IFNTT numbers = INTT_gen (length numbers) numbers" +proof(induction l arbitrary: numbers) + case 0 + hence "local.IFNTT numbers = [numbers ! 0]" + by (metis IFNTT.simps(2) One_nat_def Suc_length_conv length_0_conv nth_Cons_0 power_0) + then show ?case unfolding INTT_gen_def intt_gen_def + using 0 by simp +next + case (Suc l) + text \We define some lists that are used during the recursive call.\ + define numbers1 where "numbers1 = [numbers!i . i <- (filter even [0.. x k. x*(\^( (n div (length numbers)) * k))) + ifntt2 [0..<((length numbers) div 2)])" + define sum2 where + "sum2 = map2 (-) ifntt1 (map2 ( \ x k. x*(\^( (n div (length numbers)) * k))) + ifntt2 [0..<((length numbers) div 2)])" + define l1 where "l1 = length numbers1" + define l2 where "l2 = length numbers2" + define llen where "llen = length numbers" + + text \Properties of those lists\ + have numbers1_even: "length numbers1 = 2^l" + using numbers1_def length_even_filter Suc by simp + have numbers2_even: "length numbers2 = 2^l" + using numbers2_def length_odd_filter Suc by simp + have numbers1_ifntt: "ifntt1 = INTT_gen (2^l) numbers1" + using ifntt1_def Suc.IH[of numbers1] numbers1_even Suc(3) by simp + hence ifntt1_by_index: "ifntt1 ! i = intt_gen numbers1 (2^l) i" if "i < 2^l" for i + unfolding INTT_gen_def by (simp add: numbers1_even that) + have numbers2_ifntt: "ifntt2 = INTT_gen (2^l) numbers2" + using ifntt2_def Suc.IH[of numbers2] numbers2_even Suc(3) by simp + hence ifntt2_by_index: "ifntt2 ! i = intt_gen numbers2 (2^l) i" if "i < 2^l" for i + unfolding INTT_gen_def by (simp add: numbers2_even that) + have ifntt1_length: "length ifntt1 = 2^l" unfolding numbers1_ifntt INTT_gen_def numbers1_def + using numbers1_def numbers1_even by force + have ifntt2_length: "length ifntt2 = 2^l" unfolding numbers2_ifntt INTT_gen_def numbers2_def + using numbers2_def numbers2_even by force + + text \Same proof structure as for the \textit{FNTT} proof. + $\omega$s are just replaced by $\mu$s.\ + have before_half: "map (intt_gen numbers llen) [0..<(llen div 2)] = sum1" + proof- + + text \Length is important, since we want to use list lemmas later on.\ + have 00:"length (map (intt_gen numbers llen) [0..<(llen div 2)]) = length sum1" + unfolding sum1_def llen_def + using Suc(2) map2_length[of _ ifntt2 "[0..x y. x * \ ^ (n div length numbers * y)) ifntt2 [0..We show equality by extensionality on indices.\ + have 02:"(map (intt_gen numbers llen) [0..<(llen div 2)]) ! i = sum1 ! i" + if "i < 2^l" for i + proof- + text \First simplify this term.\ + have 000:"(map (intt_gen numbers llen) [0..<(llen div 2)]) ! i = intt_gen numbers llen i" + using "00" "01" that by auto + + text \Expand the definition of $sum1$ and massage the result.\ + moreover have 001:"sum1 ! i = (ifntt1!i) + (ifntt2!i) * (\^((n div llen) * i))" + unfolding sum1_def using map2_index + "00" "01" INTT_gen_def add.left_neutral diff_zero ifntt1_length length_map length_upt map2_map_map map_nth nth_upt numbers2_even numbers2_ifntt that llen_def by force + moreover have 002:"(ifntt1!i) = (\j=0..^((n div (2^l))*i*j))" + unfolding l1_def + using ifntt1_by_index[of i] that unfolding intt_gen_def by simp + have 003:"... = (\j=0..^((n div llen)*i*(2*j)))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers1_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l1 " + hence "map ((!) numbers) (filter even [0.. ^ (n div 2 ^ l * i * j) = + numbers ! (2 * j) * \ ^ (n div llen * i * (2 * j))" + unfolding llen_def l1_def l2_def by (metis (mono_tags, lifting) mult.assoc mult.left_commute) + qed + done + moreover have 004: + "(ifntt2!i) * (\^((n div llen) * i)) = + (\j=0..^((n div (2^l))*i*j+ (n div llen) * i))" + apply(rule trans[where s = "(\j = 0.. ^ (n div 2 ^ l * i * j) * \ ^ (n div llen * i))"]) + subgoal + unfolding l2_def llen_def + using ifntt2_by_index[of i] that sum_in[of _ "(\^((n div llen) * i))" "l2"] comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding intt_gen_def + using sum_rules apply presburger + done + apply (rule sum_rules(2)) + subgoal for j + using ifntt2_by_index[of i] that sum_in[of _ "(\^((n div llen) * i))" "l2"] comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding intt_gen_def + apply auto + done + done + have 005: "\ = (\j=0..^((n div llen)*i*(2*j+1))))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers2_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l2 " + hence "map ((!) numbers) (filter odd [0.. ^ (n div 2 ^ l * i * j + n div llen * i) + = numbers ! (2 * j + 1) * \ ^ (n div llen * i * (2 * j + 1))" unfolding llen_def + by (smt (z3) Groups.mult_ac(2) distrib_left mult.right_neutral mult_2 mult_cancel_left) + qed + done + then show ?thesis + using 000 001 002 003 004 005 + unfolding sum1_def llen_def l1_def l2_def + using sum_splice_other_way_round[of "\ d. numbers ! d * \ ^ (n div length numbers * i * d)" "2^l"] Suc(2) + unfolding intt_gen_def + by (smt (z3) Groups.mult_ac(2) numbers1_even numbers2_even power_Suc2) + qed + then show ?thesis + by (metis "00" "01" nth_equalityI) + qed + + text \We show index-wise equality for the second halves\ + have after_half: "map (intt_gen numbers llen) [(llen div 2)..Equality for every index\ + have 02:"(map (intt_gen numbers llen) [(llen div 2)..x y. x * \ ^ (n div llen * y)) ifntt2 [0.. ^ (n div llen * i)" + using Suc(2) that by (simp add: ifntt2_length llen_def) + have 003: "- ifntt2 ! i * \ ^ (n div llen * i) = ifntt2 ! i * \ ^ (n div llen * (i+ llen div 2))" + using Suc(2) my_div_exp_min1[of l] unfolding llen_def + by (smt (z3) Suc.prems(2) mult.commute mult.left_commute mult_1s_ring_1(2) neq0_conv nonzero_mult_div_cancel_left numeral_One pos2 power_Suc power_add power_mult) + hence 004:"sum2 ! i = (ifntt1!i) - (ifntt2!i) * (\^((n div llen) * i))" + unfolding sum2_def llen_def + by (simp add: Suc.prems(1) ifntt1_length ifntt2_length that) + have 005:"(ifntt1!i) = + (\j=0..^((n div (2^l))*i*j))" + using ifntt1_by_index that unfolding intt_gen_def l1_def by simp + have 006:"\ =(\j=0..^((n div llen)*i*(2*j)))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers1_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l1 " + hence "map ((!) numbers) (filter even [0.. ^ (n div 2 ^ l * i * j) = + numbers ! (2 * j) * \ ^ (n div llen * i * (2 * j))" + by (metis (mono_tags, lifting) mult.assoc mult.left_commute) + qed + done + have 007:"\ = (\j=0.. ^((n div llen)*(2^l + i)*(2*j))) " + apply (rule sum_rules(2)) + subgoal for j + using Suc(2) Suc(3) my_div_exp_min1[of l] llen_def l1_def numbers1_def + apply(smt (verit, del_insts) add.commute minus_power_mult_self mult_2 mult_minus1_right power_add power_mult) + done + done + moreover have 008: "(ifntt2!i) * (\^((n div llen) * i)) = + (\j=0..^((n div (2^l))*i*j+ (n div llen) * i))" + apply(rule trans[where s = "(\j = 0.. ^ (n div 2 ^ l * i * j) * \ ^ (n div llen * i))"]) + subgoal + using ifntt2_by_index[of i] that sum_in comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding intt_gen_def + using sum_rules l2_def apply presburger + done + apply (rule sum_rules(2)) + subgoal for j + using ifntt2_by_index[of i] that sum_in comm_semiring_1_class.semiring_normalization_rules(26)[of \] + unfolding intt_gen_def + apply auto + done + done + have 009: "\ = (\j=0..^((n div llen)*i*(2*j+1))))" + apply (rule sum_rules(2)) + subgoal for j unfolding numbers2_def + apply(subst llen_def[symmetric]) + proof- + assume ass: "j < l2 " + hence "map ((!) numbers) (filter odd [0.. ^ (n div 2 ^ l * i * j + n div llen * i) + = numbers ! (2 * j + 1) * \ ^ (n div llen * i * (2 * j + 1))" + by (smt (z3) Groups.mult_ac(2) distrib_left mult.right_neutral mult_2 mult_cancel_left) + qed + done + have 010: " (ifntt2!i) * (\^((n div llen) * i)) = (\j=0..^((n div llen)*i*(2*j+1)))) " + using 008 009 by presburger + have 011: " - (ifntt2!i) * (\^((n div llen) * i)) = + (\j=0..^((n div llen)*i*(2*j+1)))) " + apply(rule neg_cong) + apply(rule trans[where s="(\j=0..^((n div llen)*i*(2*j+1))))"]) + subgoal using 008 009 by simp + apply(rule sym) + using sum_neg_in[of _ "l2"] + apply simp + done + have 012: "\ = (\j=0..^((n div llen)*(2^l+i)*(2*j+1))))" + apply(rule sum_rules(2)) + subgoal for j + using Suc(2) Suc(3) my_div_exp_min1[of l] llen_def l2_def + apply (smt (z3) add.commute exp_rule mult.assoc mult_minus1_right plus_1_eq_Suc power_add power_minus1_odd power_mult) + done + done + have 013:"ifntt1 ! i = (\j = 0..<2 ^ l. numbers!(2*j) * \ ^ (n div llen * (2^l + i) * (2*j)))" + using 005 006 007 numbers1_even llen_def l1_def by auto + have 014: "(\j = 0..<2 ^ l. numbers ! (2*j + 1) * \ ^ (n div llen* (2^l + i) * (2*j + 1))) = + - ifntt2 ! i * \ ^ (n div llen * i)" + using trans[OF l2_def numbers2_even] sym[OF 012] sym[OF 011] by simp + have "intt_gen numbers llen (2 ^ l + i) = (ifntt1!i) - (ifntt2!i) * (\^((n div llen) * i))" + unfolding intt_gen_def + apply(subst Suc(2)) + using sum_splice[of "\ d. numbers ! d * \ ^ (n div llen * (2^l+i) * d)" "2^l"] sym[OF 013] 014 Suc(2) by simp + thus ?thesis using 000 sym[OF 001] "004" sum2_def by simp + qed + then show ?thesis + by (metis "00" "01" list_eq_iff_nth_eq) + qed + obtain x y xs where xyxs: "numbers = x#y#xs" using Suc(2) + by (metis FNTT.cases add.left_neutral even_Suc even_add length_Cons list.size(3) mult_2 power_Suc power_eq_0_iff zero_neq_numeral) + show ?case + apply(subst xyxs) + apply(subst IFNTT.simps(3)) + apply(subst xyxs[symmetric])+ + unfolding Let_def + using map_append[of "intt_gen numbers llen" " [0..Correctness of the butterfly scheme for the inverse \textit{INTT}.\ + +theorem IFNTT_correct: + assumes "length numbers = n" + shows "IFNTT numbers = INTT numbers" + using IFNTT_INTT_gen_eq INTT_gen_INTT_full_length assms n_two_pot by force + +text \Also $FNTT$ and $IFNTT$ are mutually inverse\ + +theorem IFNTT_inv_FNTT: + assumes "length numbers = n" + shows "IFNTT (FNTT numbers) = map ((*) (of_int_mod_ring (int n))) numbers" + by (simp add: FNTT_correct IFNTT_correct assms length_NTT ntt_correct) + +text \The other way round:\ + +theorem FNTT_inv_IFNTT: + assumes "length numbers = n" + shows "FNTT (IFNTT numbers) = map ((*) (of_int_mod_ring (int n))) numbers" +by (simp add: FNTT_correct IFNTT_correct assms inv_ntt_correct length_INTT) + +subsection \An Optimization\ +text \Currently, we extract elements on even and odd positions respectively by a list comprehension + over even and odd indices. +Due to the definition in Isabelle, an index access has linear time complexity. +This results in quadratic running time complexity for every level +in the recursion tree of the \textit{FNTT}. +In order to reach the $\mathcal{O}(n \log n)$ time bound, +we have find a better way of splitting the elements at even or odd indices respectively. +\ + +text \A core of this optimization is the $evens\text{-}odds$ function, + which splits the vectors in linear time.\ + +fun evens_odds::"bool \'b list \ 'b list" where +"evens_odds _ [] = []"| +"evens_odds True (x#xs)= (x# evens_odds False xs)"| +"evens_odds False (x#xs) = evens_odds True xs" + +lemma map_filter_shift: " map f (filter even [0.. x. f (x+1)) (filter odd [0.. x. f (x+1)) (filter even [0..A splitting by the $evens\text{-}odds$ function is +equivalent to the more textbook-like list comprehension.\ + +lemma filter_compehension_evens_odds: + "[xs ! i. i <- filter even [0.. + [xs ! i. i <- filter odd [0..For automated termination proof.\ + +lemma [simp]: "length (evens_odds True vc) < Suc (length vc)" + "length (evens_odds False vc) < Suc (length vc)" + by (metis filter_compehension_evens_odds le_imp_less_Suc length_filter_le length_map map_nth)+ + + +text \The $FNTT$ definition from above was suitable for matters of proof conduction. + However, the naive decomposition into elements at odd and even indices induces a complexity of $n^2$ in every recursive step. +As mentioned, the $evens\text{-}odds$ function filters for elements on even or odd positions respectively. +The list has to be traversed only once which gives \textit{linear} complexity for every recursive step. \ + +fun FNTT' where +"FNTT' [] = []"| +"FNTT' [a] = [a]"| +"FNTT' nums = (let nn = length nums; + nums1 = evens_odds True nums; + nums2 = evens_odds False nums; + fntt1 = FNTT' nums1; + fntt2 = FNTT' nums2; + fntt2_omg = (map2 ( \ x k. x*(\^( (n div nn) * k))) fntt2 [0..<(nn div 2)]); + sum1 = map2 (+) fntt1 fntt2_omg; + sum2 = map2 (-) fntt1 fntt2_omg + in sum1@sum2)" + +text \The optimized \textit{FNTT} is equivalent to the naive \textit{NTT}.\ + +lemma FNTT'_FNTT: "FNTT' xs = FNTT xs" + apply(induction xs rule: FNTT'.induct) + subgoal by simp + subgoal by simp + apply(subst FNTT'.simps(3)) + apply(subst FNTT.simps(3)) + subgoal for a b xs + unfolding Let_def + apply (metis filter_compehension_evens_odds) + done + done + +text \It is quite surprising that some inaccuracies in the interpretation of informal textbook definitions +- even when just considering such a simple algorithm - can indeed affect time complexity.\ + +subsection \Arguments on Running Time\ + +text \ $FFT$ is especially known for its $\mathcal{O}(n \log n)$ running time. +Unfortunately, Isabelle does not provide a built-in time formalization. +Nonetheless we can reason about running time after defining some "reasonable" consumption functions by hand. +Our approach loosely follows a general pattern by Nipkow et al.~\parencite{funalgs}. +First, we give running times and lemmas for the auxiliary functions used during FNTT.\\ +General ideas behind the $\mathcal{O}(n \log n)$ are: +\begin{itemize} +\item By recursively halving the problem size, we obtain a tree of depth $\mathcal{O}(\log n)$. +\item For every level of that tree, we have to process all elements which gives $\mathcal{O}(n)$ time. +\end{itemize} + +\ + +text \Time for splitting the list according to even and odd indices.\ + +fun T_\<^sub>e\<^sub>o::"bool \ 'c list \ nat" where +" T_\<^sub>e\<^sub>o _ [] = 1"| +" T_\<^sub>e\<^sub>o True (x#xs)= (1+ T_\<^sub>e\<^sub>o False xs)"| +" T_\<^sub>e\<^sub>o False (x#xs) = (1+ T_\<^sub>e\<^sub>o True xs)" + +lemma T_eo_linear: "T_\<^sub>e\<^sub>o b xs = length xs + 1" + by (induction b xs rule: T_\<^sub>e\<^sub>o.induct) auto + +text \Time for length.\ + +fun T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h where +"T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h [] = 1 "| +"T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h (x#xs) = 1+ T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h xs" + +lemma T_length_linear: "T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h xs = length xs +1" + by (induction xs) auto + +text \Time for index access.\ + +fun T\<^sub>n\<^sub>t\<^sub>h where +"T\<^sub>n\<^sub>t\<^sub>h [] i = 1 "| +"T\<^sub>n\<^sub>t\<^sub>h (x#xs) 0 = 1"| +"T\<^sub>n\<^sub>t\<^sub>h (x#xs) (Suc i) = 1 + T\<^sub>n\<^sub>t\<^sub>h xs i" + +lemma T_nth_linear: "T\<^sub>n\<^sub>t\<^sub>h xs i \ length xs +1" + by (induction xs i rule: T\<^sub>n\<^sub>t\<^sub>h.induct) auto + +text \Time for mapping two lists into one result.\ + +fun T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 where + "T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 t [] _ = 1"| + "T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 t _ [] = 1"| + "T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 t (x#xs) (y#ys) = (t x y + 1 + T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 t xs ys)" + +lemma T_map_2_linear: +"c > 0 \ + (\ x y. t x y \ c) \ T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 t xs ys \ min (length xs) (length ys) * (c+1) + 1" + apply(induction t xs ys rule: T\<^sub>m\<^sub>a\<^sub>p\<^sub>2.induct) + subgoal by simp + subgoal by simp + subgoal for t x xs y ys + apply(subst T\<^sub>m\<^sub>a\<^sub>p\<^sub>2.simps, subst length_Cons, subst length_Cons) + using min_add_distrib_right[of 1] + by (smt (z3) Suc_eq_plus1 add.assoc add.commute add_le_mono le_numeral_extra(4) min_def mult.commute mult_Suc_right) + done + +lemma T_map_2_linear': +"c > 0 \ + (\ x y. t x y = c) \ T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 t xs ys = min (length xs) (length ys) * (c+1) + 1" + by(induction t xs ys rule: T\<^sub>m\<^sub>a\<^sub>p\<^sub>2.induct) simp+ + + +text \Time for append.\ + +fun T\<^sub>a\<^sub>p\<^sub>p where + " T\<^sub>a\<^sub>p\<^sub>p [] _ = 1"| + " T\<^sub>a\<^sub>p\<^sub>p (x#xs) ys = 1 + T\<^sub>a\<^sub>p\<^sub>p xs ys" + +lemma T_app_linear: " T\<^sub>a\<^sub>p\<^sub>p xs ys = length xs +1" + by(induction xs) auto + + +text \Running Time of (optimized) $FNTT$.\ + +fun T\<^sub>F\<^sub>N\<^sub>T\<^sub>T::"('a mod_ring) list \ nat" where +"T\<^sub>F\<^sub>N\<^sub>T\<^sub>T [] = 1"| +"T\<^sub>F\<^sub>N\<^sub>T\<^sub>T [a] = 1"| +"T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums = (1 +T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h nums+ 3+ + + (let nn = length nums; + nums1 = evens_odds True nums; + nums2 = evens_odds False nums + in + T_\<^sub>e\<^sub>o True nums + T_\<^sub>e\<^sub>o False nums + 2 + + (let + fntt1 = FNTT nums1; + fntt2 = FNTT nums2 + in + (T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums1) + (T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums2) + + (let + sum1 = map2 (+) fntt1 (map2 ( \ x k. x*(\^( (n div nn) * k))) fntt2 [0..<(nn div 2)]); + sum2 = map2 (-) fntt1 (map2 ( \ x k. x*(\^( (n div nn) * k))) fntt2 [0..<(nn div 2)]) + in + 2* T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 (\ x y. 1) fntt2 [0..<(nn div 2)] + + 2* T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 (\ x y. 1) fntt1 (map2 ( \ x k. x*(\^( (n div nn) * k))) fntt2 [0..<(nn div 2)]) + + T\<^sub>a\<^sub>p\<^sub>p sum1 sum2))))" + +lemma mono: "((f x)::nat) \ f y \ f y \ fz \ f x \ fz" by simp + +lemma evens_odds_length: + "length (evens_odds True xs) = (length xs+1) div 2 \ + length (evens_odds False xs) = (length xs) div 2" + by(induction xs) simp+ + +text \Length preservation during $FNTT$.\ + +lemma FNTT_length: "length numbers = 2^l \ length (FNTT numbers) = length numbers" +proof(induction l arbitrary: numbers) + case (Suc l) + define numbers1 where "numbers1 = [numbers!i . i <- (filter even [0.. x k. x*(\^( (n div (length numbers)) * k))) + fntt2 [0..<((length numbers) div 2)])" + define sum1 where + "sum1 = map2 (+) fntt1 presum" + define sum2 where + "sum2 = map2 (-) fntt1 presum" + have "length numbers1 = 2^l" + by (metis Suc.prems numbers1_def diff_add_inverse2 length_even_filter mult_2 nonzero_mult_div_cancel_left power_Suc zero_neq_numeral) + hence "length fntt1 = 2^l" + by (simp add: Suc.IH fntt1_def) + hence "length presum = 2^l" unfolding presum_def + using map2_length Suc.IH Suc.prems fntt2_def length_odd_filter numbers2_def by force + hence "length sum1 = 2^l" + by (simp add: \length fntt1 = 2 ^ l\ sum1_def) + have "length numbers2 = 2^l" + by (metis Suc.prems numbers2_def length_odd_filter nonzero_mult_div_cancel_left power_Suc zero_neq_numeral) + hence "length fntt2 = 2^l" + by (simp add: Suc.IH fntt2_def) + hence "length sum2 = 2^l" unfolding sum2_def + using \length sum1 = 2 ^ l\ sum1_def by force + hence final:"length (sum1@sum2) = 2^(Suc l)" + by (simp add: \length sum1 = 2 ^ l\) + obtain x y xs where xyxs_Def: "numbers = x#y#xs" + by (metis \length numbers2 = 2 ^ l\ evens_odds.elims filter_compehension_evens_odds length_0_conv neq_Nil_conv numbers2_def power_eq_0_iff zero_neq_numeral) + show ?case + apply(subst xyxs_Def, subst FNTT.simps(3), subst xyxs_Def[symmetric]) + unfolding Let_def + using final + unfolding sum1_def sum2_def presum_def fntt1_def fntt2_def numbers1_def numbers2_def + using Suc by (metis xyxs_Def) +qed (metis FNTT.simps(2) Suc_length_conv length_0_conv nat_power_eq_Suc_0_iff) + +lemma add_cong: "(a1::nat) + a2+a3 +a4= b \ a1 +a2+ c + a3+a4= c +b" + by simp + +lemma add_mono:"a \ (b::nat) \ c \ d \ a + c \ b +d" by simp + +lemma xyz: " Suc (Suc (length xs)) = 2 ^ l \ length (x # evens_odds True xs) = 2 ^ (l - 1)" + by (metis (no_types, lifting) Nat.add_0_right Suc_eq_plus1 div2_Suc_Suc div_mult_self2 evens_odds_length length_Cons nat.distinct(1) numeral_2_eq_2 one_div_two_eq_zero plus_1_eq_Suc power_eq_if) + +lemma zyx:" Suc (Suc (length xs)) = 2 ^ l \ length (y # evens_odds False xs) = 2 ^ (l - 1)" + by (smt (z3) One_nat_def Suc_pred diff_Suc_1 div2_Suc_Suc evens_odds_length le_numeral_extra(4) length_Cons nat_less_le neq0_conv power_0 power_diff power_one_right zero_less_diff zero_neq_numeral) + +text \When $length \; xs = 2^l$, then $length \; (evens\text{-}odds \; xs) = 2^{l-1}$.\ + +lemma evens_odds_power_2: + fixes x::'b and y::'b + assumes "Suc (Suc (length (xs::'b list))) = 2 ^ l" + shows " Suc(length (evens_odds b xs)) = 2 ^ (l-1)" +proof- + have "Suc(length (evens_odds b xs)) = length (evens_odds b (x#y#xs))" + by (metis (full_types) evens_odds.simps(2) evens_odds.simps(3) length_Cons) + have "length (x#y#xs) = 2^l" using assms by simp + have "length (evens_odds b (x#y#xs)) = 2^(l-1)" + apply (cases b) + apply (smt (z3) Suc_eq_plus1 Suc_pred \length (x # y # xs) = 2 ^ l\ add.commute add_diff_cancel_left' assms filter_compehension_evens_odds gr0I le_add1 le_imp_less_Suc length_even_filter mult_2 nat_less_le power_diff power_eq_if power_one_right zero_neq_numeral) + by (smt (z3) One_nat_def Suc_inject \length (x # y # xs) = 2 ^ l\ assms evens_odds_length le_zero_eq nat.distinct(1) neq0_conv not_less_eq_eq pos2 power_Suc0_right power_diff_power_eq power_eq_if) + then show ?thesis + by (metis \Suc (length (evens_odds b xs)) = length (evens_odds b (x # y # xs))\) +qed + +text \ \noindent \textbf{Major Lemma:} We rewrite the Running time of $FNTT$ in this proof and collect constraints for the time bound. +Using this, bounds are chosen in a way such that the induction goes through properly. +\paragraph \noindent We define: + +\begin{equation*} +T(2^0) = 1 +\end{equation*} + +\begin{equation*} +T(2^l) = +(2^l - 1)\cdot 14 apply+ 15 \cdot l \cdot 2^{l-1} + 2^l +\end{equation*} + +\paragraph \noindent We want to show: + +\begin{equation*} +T_{FNTT}(2^l) = T(2^l) +\end{equation*} + +(Note that by abuse of types, the $2^l$ denotes a list of length $2^l$.) + +\paragraph \noindent First, let's informally check that $T$ is indeed an accurate description of the running time: + +\begin{align*} +T_{FNTT}(2^l) & \; = 14 + 15 \cdot 2 ^ {l-1} + 2 \cdot T_{FNTT}(2^{l-1}) \hspace{1cm} \text{by analyzing the running time function}\\ +&\overset{I.H.}{=} 14 + 15 \cdot 2 ^ {l-1} + 2 \cdot ((2^{l-1} - 1) \cdot 14 + (l - 1) \cdot 15 \cdot 2^{l-2} + 2^{l-1})\\ +& \;= 14 \cdot 2^l - 14 + 15 \cdot 2 ^ {l-1} + 15\cdot l \cdot 2^{l-1} - 15 \cdot 2^{l-1} + 2^l\\ +&\; = (2^l - 1)\cdot 14 + 15 \cdot l \cdot 2^{l-1} + 2^l\\ +&\overset{def.}{=} T(2^l) +\end{align*} + +The base case is trivially true. +\ + +theorem tight_bound: + assumes T_def: "\ numbers l. length numbers = 2^l \ l > 0 \ + T numbers = (2^l - 1) * 14 + l *15*2^(l-1) + 2^l" + "\ numbers l. l =0 \ length numbers = 2^l \ T numbers = 1" + shows " length numbers = 2^l \ T\<^sub>F\<^sub>N\<^sub>T\<^sub>T numbers = T numbers" +proof(induction numbers arbitrary: l rule: T\<^sub>F\<^sub>N\<^sub>T\<^sub>T.induct) + case (3 x y numbers) + + text \Some definitions for making term rewriting simpler.\ + + define nn where "nn = length (x # y # numbers)" + define nums1 where "nums1 = evens_odds True (x # y # numbers)" + define nums2 where "nums2 = evens_odds False (x # y # numbers)" + define fntt1 where "fntt1 = local.FNTT nums1" + define fntt2 where "fntt2 = local.FNTT nums2" + define sum1 where "sum1 = map2 (+) fntt1 (map2 (\x y. x * \ ^ (n div nn * y)) fntt2 [0..x y. x * \ ^ (n div nn * y)) fntt2 [0..Unfolding the running time function and combining it with the definitions above.\ + + have TFNNT_simp: " T\<^sub>F\<^sub>N\<^sub>T\<^sub>T (x # y # numbers) = + 1 + T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h (x # y # numbers) + 3 + + T_\<^sub>e\<^sub>o True (x # y # numbers) + T_\<^sub>e\<^sub>o False (x # y # numbers) + 2 + + local.T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums1 + local.T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums2 + + 2 * T\<^sub>m\<^sub>a\<^sub>p\<^sub>2 (\x y. 1) fntt2 [0..m\<^sub>a\<^sub>p\<^sub>2 (\x y. 1) fntt1 (map2 (\x y. x * \ ^ (n div nn * y)) fntt2 [0..a\<^sub>p\<^sub>p sum1 sum2" + apply(subst T\<^sub>F\<^sub>N\<^sub>T\<^sub>T.simps(3)) + unfolding Let_def unfolding sum2_def sum1_def fntt1_def fntt2_def nums1_def nums2_def nn_def + apply simp + done + + text \Application of lemmas related to running times of auxiliary functions.\ + + have length_nums1: "length nums1 = (2::nat)^(l-1)" + unfolding nums1_def + using evens_odds_length[of "x # y # numbers"] 3(3) xyz by fastforce + have length_nums2: "length nums2 = (2::nat)^(l-1)" + unfolding nums2_def + using evens_odds_length[of "x # y # numbers"] 3(3) + by (metis One_nat_def le_0_eq length_Cons lessI list.size(4) neq0_conv not_add_less2 not_less_eq_eq pos2 power_Suc0_right power_diff_power_eq power_eq_if) + have length_simp: "T\<^sub>l\<^sub>e\<^sub>n\<^sub>g\<^sub>t\<^sub>h (x # y # numbers) = (2::nat) ^l +1" + using T_length_linear[of "x#y#numbers"] 3(3) by simp + have even_odd_simp: " T_\<^sub>e\<^sub>o b (x # y # numbers) = (2::nat)^l + 1" for b + by (metis "3.prems" T_eo_linear)+ + have 02: "(length fntt2) = (length [0..m\<^sub>a\<^sub>p\<^sub>2 (\x y. 1) fntt2 [0..m\<^sub>a\<^sub>p\<^sub>2 (\x y. 1) fntt1 (map2 (\x y. x * \ ^ (n div nn * y)) fntt2 [0..a\<^sub>p\<^sub>p sum1 sum2 = (2::nat)^(l-1) + 1" + by(subst T_app_linear, subst sum1_simp, simp) + let ?T1 = "(2^(l-1) - 1) * 14 + (l-1) *15*2^(l-1 -1) + 2^(l-1)" + + text \Induction hypotheses\ + + have IH_pluged1: "local.T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums1 = ?T1" + apply(subst "3.IH"(1)[of nn nums1 nums2 fntt1 fntt2 "l-1", + OF nn_def nums1_def nums2_def fntt1_def fntt2_def length_nums1]) + apply(cases "l \ 1") + subgoal + apply(subst T_def(2)[of "l-1"]) + subgoal by simp + apply(rule length_nums1) + apply simp + done + apply(subst T_def(1)[OF length_nums1]) + subgoal by simp + subgoal by simp + done + + have IH_pluged2: "local.T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums2 = ?T1" + apply(subst "3.IH"(2)[of nn nums1 _ fntt1 fntt2 "l-1", OF nn_def nums1_def nums2_def fntt1_def + fntt2_def length_nums2 ]) + apply(cases "l \ 1") + subgoal + apply(subst T_def(2)[of "l-1"]) + subgoal by simp + apply(rule length_nums2) + apply simp + done + apply(subst T_def(1)[OF length_nums2]) + subgoal by simp + subgoal by simp + done + + have " T\<^sub>F\<^sub>N\<^sub>T\<^sub>T (x # y # numbers) = + 14 + (3 * 2 ^ l + (local.T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums1 + + (local.T\<^sub>F\<^sub>N\<^sub>T\<^sub>T nums2 + (5 * 2^(l-1) + 4 * (2 ^ l div 2)))))" + apply(subst TFNNT_simp, subst map21_simp, subst map22_simp, subst length_simp, + subst app_simp, subst even_odd_simp, subst even_odd_simp) + apply(auto simp add: algebra_simps power_eq_if[of 2 l]) + done + + text \Proof that the term $T\text{-}def$ indeed fulfills the recursive properties, i.e. + $t(2^l) = 2 \cdot t(2^{l-1}) + s$\ + + also have "\ = 14 + (3 * 2 ^ l + (?T1 + (?T1 + (5 * 2^(l-1) + 4 * (2 ^ l div 2)))))" + apply(subst IH_pluged1, subst IH_pluged2) + by simp + also have "\ = 14 + (6 * 2 ^ (l-1) + + 2*((2 ^ (l - 1) - 1) * 14 + (l - 1) * 15 * 2 ^ (l - 1 - 1) + 2 ^ (l - 1)) + + (5 * 2 ^ (l - 1) + 4 * (2 ^ l div 2)))" + by (smt (verit) "3"(3) add.assoc div_less evens_odds_length left_add_twice length_nums2 lessI mult.assoc mult_2_right nat_1_add_1 numeral_Bit0 nums2_def plus_1_eq_Suc power_eq_if power_not_zero zero_neq_numeral) + also have "\ = 14 + 15 * 2 ^ (l-1) + + 2*((2 ^ (l - 1) - 1) * 14 + (l - 1) * 15 * 2 ^ (l - 1 - 1) + 2 ^ (l - 1))" + by (smt (z3) "3"(3) add.assoc add.commute calculation diff_diff_left distrib_left div2_Suc_Suc evens_odds_length left_add_twice length_Cons length_nums2 mult.assoc mult.commute mult_2 mult_2_right numeral_Bit0 numeral_Bit1 numeral_plus_numeral nums2_def one_add_one) + also have "... = 14 + 15 * 2 ^ (l-1) + + (2 ^ l - 2) * 14 + (l - 1) * 15 * 2 ^ (l - 1) + 2 ^ l" + apply(cases "l > 1") + apply (smt (verit, del_insts) add.assoc diff_is_0_eq distrib_left_numeral left_diff_distrib' less_imp_le_nat mult.assoc mult_2 mult_2_right nat_1_add_1 not_le not_one_le_zero power_eq_if) + by (smt (z3) "3"(3) add.commute add.right_neutral cancel_comm_monoid_add_class.diff_cancel diff_add_inverse2 diff_is_0_eq div_less_dividend evens_odds_length length_nums2 mult_2 mult_eq_0_iff nat_1_add_1 not_le nums2_def power_eq_if) + also have "\ = 15 * 2 ^ (l - 1) + (2 ^ l - 1) * 14 + (l - 1) * 15 * 2 ^ (l - 1) + 2 ^ l" + by (smt (z3) "3"(3) One_nat_def add.commute combine_common_factor diff_add_inverse2 diff_diff_left list.size(4) nat_1_add_1 nat_mult_1) + also have "\ = (2^l - 1) * 14 + l *15*2^(l-1) + 2^l" + apply(cases "l > 0") + subgoal using group_cancel.add1 group_cancel.add2 less_numeral_extra(3) mult.assoc mult_eq_if by auto[1] + using "3"(3) by fastforce + + text \By the previous proposition, we can conclude that $T$ is indeed a suitable term for describing the running time\ + + finally have "T\<^sub>F\<^sub>N\<^sub>T\<^sub>T (x # y # numbers) = T (x # y # numbers)" + using T_def(1)[of "x#y#numbers" l] + by (metis "3.prems" bits_1_div_2 diff_is_0_eq' evens_odds_length length_nums2 neq0_conv nums2_def power_0 zero_le_one zero_neq_one) + thus ?case by simp +qed (auto simp add: assms) + +text \We can finally state that $FNTT$ has $\mathcal{O}(n \log n)$ time complexity.\ + +theorem log_lin_time: + assumes "length numbers = 2^l" + shows "T\<^sub>F\<^sub>N\<^sub>T\<^sub>T numbers \ 30 * l * length numbers + 1" +proof- + have 00: "T\<^sub>F\<^sub>N\<^sub>T\<^sub>T numbers = (2 ^ l - 1) * 14 + l * 15 * 2 ^ (l - 1) + 2 ^ l" + using tight_bound[of "\ xs. (length xs - 1) * 14 + (Discrete.log (length xs)) * 15 * + 2 ^ ( (Discrete.log (length xs)) - 1) + length xs" numbers l] + assms by simp + have " l * 15 * 2 ^ (l - 1) \ 15 * l * length numbers" using assms by simp + moreover have "(2 ^ l - 1) * 14 + 2^l\ 15 * length numbers " + using assms by linarith + moreover hence "(2 ^ l - 1) * 14 + 2^l \ 15 * l * length numbers +1" using assms + apply(cases l) + subgoal by simp + by (metis (no_types) add.commute le_add1 mult.assoc mult.commute + mult_le_mono nat_mult_1 plus_1_eq_Suc trans_le_add2) + ultimately have " (2 ^ l - 1) * 14 + l * 15 * 2 ^ (l - 1) + 2 ^ l \ 30 * l * length numbers +1" + by linarith + then show ?thesis using 00 by simp +qed + +theorem log_lin_time_explicitly: + assumes "length numbers = 2^l" + shows "T\<^sub>F\<^sub>N\<^sub>T\<^sub>T numbers \ 30 * Discrete.log (length numbers) * length numbers + 1" + using log_lin_time[of numbers l] assms by simp + +end +end diff --git a/thys/Number_Theoretic_Transform/NTT.thy b/thys/Number_Theoretic_Transform/NTT.thy new file mode 100644 --- /dev/null +++ b/thys/Number_Theoretic_Transform/NTT.thy @@ -0,0 +1,472 @@ +(* +Title: Number Theoretic Transform +Author: Thomas Ammer +*) + +theory NTT + imports Preliminary_Lemmas +begin + +section \Number Theoretic Transform and Inverse Transform\ +text \\label{NTT}\ + +locale ntt = preliminary "TYPE ('a ::prime_card)" + +fixes \ :: "('a::prime_card mod_ring)" +fixes \ :: "('a mod_ring)" +assumes omega_properties: "\^n = 1" "\ \ 1" "(\ m. \^m = 1 \ m\0 \ m \ n)" +assumes mu_properties: "\ * \ = 1" +begin + +lemma mu_properties': "\ \ 1" + using omega_properties mu_properties by auto + +subsection \Definition of $NTT$ and $INTT$\ +text \\label{NTTdef}\ + +text \ +Now we can state an analogue to the $DFT$ on finite fields, +namely the \textit{Number Theoretic Transform}. +First, let us look at an informal definition of $\mathsf{NTT}$~\parencite{ntt_intro}: +\begin{equation*} +\mathsf{NTT}(\vec{x}) = +\begin{pmatrix} + 1 & 1 & 1 & 1 & \cdots& 1 \\ + 1 & \omega & \omega^2 & \omega^3 & \cdots & \omega^{n-1} \\ + 1 & \omega^2 & \omega^4 & \omega^6 & \cdots & \omega^{2\cdot(n-1)} \\ + 1 & \omega^3 & \omega^6 & \omega^9 & \cdots & \omega^{3\cdot(n-1)} \\ +\vdots & \vdots & \vdots & \vdots & & \vdots \\ + 1 & \omega^{n-1} & \omega^{2\cdot(n-1)} & \omega^{3\cdot(n-1)} & \cdots & \omega^{(n-1)\cdot(n-1)} +\end{pmatrix} \cdot \vec{x} +\end{equation*} + +Or for single vector entries: +\begin{equation*} +\mathsf{NTT}(\vec{x})_i = \sum _{j = 0} ^{n-1} x_j \cdot \omega ^{i\cdot j} +\end{equation*} + +\ + +text \Formally:\ + +definition ntt::"(('a ::prime_card) mod_ring) list \ nat \ 'a mod_ring" where +"ntt numbers i = (\j=0..^(i*j)) " + +definition "NTT numbers = map (ntt numbers) [0..\label{INTTdef} +We define the inverse transform $\mathsf{INTT}$ by matrices: +\begin{equation*} + \mathsf{INTT}(\vec{y}) = +\begin{pmatrix} + 1 & 1 & 1 & 1 & \cdots& 1 \\ + 1 & \mu & \mu^2 & \mu^3 & \cdots & \mu^{n-1} \\ + 1 & \mu^2 & \mu^4 & \mu^6 & \cdots & \mu^{2\cdot(n-1)} \\ + 1 & \mu^3 & \mu^6 & \mu^9 & \cdots & \mu^{3\cdot(n-1)} \\ +\vdots & \vdots & \vdots & \vdots & & \vdots \\ + 1 & \mu^{n-1} & \mu^{2\cdot(n-1)} & \mu^{3\cdot(n-1)} & \cdots & \mu^{(n-1)\cdot(n-1)} +\end{pmatrix} \cdot \vec{y} +\end{equation*} +Per component: +\begin{equation*} +% +\mathsf{INTT}(\vec{y})_i = \sum _{j = 0} ^{n-1} y_j \cdot \mu ^{i\cdot j} +% +\end{equation*} + +\ + +definition "intt xs i = (\j=0..^(i*j)) " + +definition "INTT xs = map (intt xs) [0..Vector length is preserved.\ + +lemma length_NTT: + assumes n_def: "length numbers = n" + shows "length (NTT numbers) = n" + unfolding NTT_def ntt_def using n_def length_map[of _ "[0..Correctness Proof of $NTT$ and $INTT$\ +text \\label{NTTcorr}\ +text \ +We prove $\mathsf{NTT}$ and $\mathsf{INTT}$ correct: +By taking $\mathsf{INTT}(\mathsf{NTT} (x))$ we obtain $x$ scaled by $n$. +Analogue to $DFT$, one can get rid of the factor $n$ by a simple rescaling. +First, consider an informal proof sketch using the matrix form: +\begin{equation*} +\begin{split} +\mathsf{INTT}(\mathsf{NTT}(\vec{x})) = \hspace{11cm}\\ +% +\begin{pmatrix} + 1 & 1 & 1 & \cdots& 1 \\ + 1 & \mu & \mu^2 & \cdots & \mu^{n-1} \\ + 1 & \mu^2 & \mu^4 & \cdots & \mu^{2\cdot(n-1)} \\ +\vdots & \vdots & \vdots & & \vdots \\ + 1 & \mu^{n-1} & \mu^{2\cdot(n-1)}& \cdots & \mu^{(n-1)\cdot(n-1)} +\end{pmatrix} +% +\cdot +% +\begin{pmatrix} + 1 & 1 & 1 & \cdots& 1 \\ + 1 & \omega & \omega^2 & \cdots & \omega^{n-1} \\ + 1 & \omega^2 & \omega^4 & \cdots & \omega^{2\cdot(n-1)} \\ +\vdots & \vdots & \vdots & & \vdots \\ + 1 & \omega^{n-1} & \omega^{2\cdot(n-1)} & \cdots & \omega^{(n-1)\cdot(n-1)} +\end{pmatrix} +\cdot +% +\vec{x} +% +\end{split} +\end{equation*} + +A resulting entry is of the following form: + +\begin{equation*} +% +\mathsf{INTT}(\mathsf{NTT}(x))_i = \sum _ {j = 0} ^{n-1} (\sum _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k}) \cdot x_j +% +\end{equation*} + +Now, we analyze the interior sum by cases on $i = j$. + +\paragraph \noindent Case $i = j$. +\begin{align*} +\sum _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k} +&= \sum _{k = 0} ^{n-1} (\mu \cdot \omega)^{i \cdot k} \\ +&= n \cdot (\mu \cdot \omega)^{i \cdot k} \\ +&= n \cdot 1 ^{i \cdot k} \\ &= n +\end{align*} +Note that $\omega$ and $\mu$ are mutually inverse. +\paragraph \noindent Case $i \neq j$. Wlog assume $i > j$, otherwise replace $\omega$ by $\mu$ and $i -j$ by $j - i$ respectively. +\begin{align*} +\sum _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k} +&= \sum _{k = 0} ^{n-1} (\mu \cdot \omega)^{j \cdot k} \cdot \omega^{(i-j) \cdot k} \\ +&= \sum _{k = 0} ^{n-1} \omega^{(i-j) \cdot k} \\ +&= (1 - \omega ^{(i-j)\cdot n}) \cdot (1 - \omega^{i-j})^{-1} && \text{by lemma on geometric sum}\\ +&= (1 - 1^n) \cdot (1 - \omega^{i-j})^{-1} \\ +&= 0 +\end{align*} + +We conclude that $\sum \limits _ {j = 0} ^{n-1} (\sum \limits _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k}) \cdot x_j = n \cdot x_i$. + +\ + +theorem ntt_correct: + assumes n_def: "length numbers = n" + shows "INTT (NTT numbers) = map (\ x. (of_int_mod_ring n) * x ) numbers" +proof- + have 0:"\ i. i < n \ (INTT (NTT numbers)) ! i = intt (NTT numbers) i " using n_def length_NTT + unfolding INTT_def NTT_def intt_def by simp + + text \Major sublemma.\ + + have 1:"\ i. i < n \intt (NTT numbers) i = (of_int_mod_ring n)*numbers ! i" + proof- + fix i + assume i_assms:"i < n" + + text \First, simplify by some chains of equations.\ + + hence 1:"intt (NTT numbers) i = + (\l = 0..j = 0.. ^ (l * j)) * \ ^ (i * l))" + unfolding NTT_def intt_def ntt_def using n_def length_map nth_map by simp + also have 2:"\ = + (\l = 0..j = 0.. ^ (l * j)) * \ ^ (i * l)))" + using sum_in by (simp add: sum_distrib_right) + also have 3:"\ = + (\j = 0..l = 0.. ^ (l * j) * \ ^ (i * l))))" using sum_swap by fast + + text \As in the informal proof, we consider three cases. First $j = i$.\ + + have iisj:"\ j. j = i \ (\l = 0.. ^ (l * j) * \ ^ (i * l))) = (numbers ! j)* (of_int_mod_ring n)" + proof- + fix j + assume "j=i" + hence "\ l. l < n \ (numbers ! j * \ ^ (l * j) * \ ^ (i * l))= (numbers ! j)" + by (simp add: left_right_inverse_power mult.commute mu_properties(1)) + moreover have "\ l. l < n \ numbers ! j * \ ^ (l * j) * \ ^ (i * l) = numbers ! j" + using calculation by blast + + text \$\omega^{il}\cdot \omega^{jl} = 1$. Thus, we sum over $1$ $n$ times, which gives the goal.\ + + ultimately show "(\l = 0.. ^ (l * j) * \ ^ (i * l))) = + (numbers ! j)* (of_int_mod_ring n)" + using n_def sum_const[of "numbers ! j" n] exp_rule[of \ \] mu_properties(1) + by (metis (no_types, lifting) atLeastLessThan_iff mult.commute sum.cong) + + qed + + text \Case $j < i$.\ + + have jlsi:"\ j. j < i \ (\l = 0.. ^ (l * j) * \ ^ (i * l))) = 0" + proof- + fix j + assume j_assms:"j < i" + hence 00:"\ (c::('a::prime_card) mod_ring) a b. c * a^j*b^i = (a*b)^j*(c * b^(i-j))" + using algebra_simps + by (smt (z3) le_less ordered_cancel_comm_monoid_diff_class.add_diff_inverse power_add) + + text \A geometric sum over $\mu^l$ remains.\ + + have 01:" (\l = 0.. ^ (l * j) * \ ^ (i * l))) = + (\l = 0..^l)^(i-j)))" + apply(rule sum_eq) + using mu_properties(1) 00 algebra_simps(23) + by (smt (z3) mult.commute mult.left_neutral power_mult power_one) + have 02:"\ = numbers ! j *(\l = 0..^l)^(i-j))) " + using sum_in[of "\ l. numbers ! j * (\ ^ l) ^ (i - j)" " numbers ! j" n] + by (simp add: mult_hom.hom_sum) + moreover have 03:"(\l = 0..^l)^(i-j))) = + (\l = 0..^(i-j))^l)) " + by(rule sum_eq) (metis mult.commute power_mult) + have "\^(i-j) \ 1" + proof + assume "\ ^ (i - j) = 1" + hence "ord p (to_int_mod_ring \) \ i-j" + by (simp add: j_assms not_le ord_max) + moreover hence "ord p (to_int_mod_ring \) \ i-j" + by (metis \\ ^ (i - j) = 1\ diff_is_0_eq exp_rule j_assms leD mult.comm_neutral mult.commute mu_properties(1) ord_max) + moreover hence "i-j < n" + using j_assms i_assms p_fact k_bound n_lst2 by linarith + moreover have "ord p (to_int_mod_ring \) = n" using omega_properties n_lst2 unfolding ord_def + by (metis (no_types) \\ ^ (i - j) = 1\ calculation(3) diff_is_0_eq j_assms leD left_right_inverse_power mult.comm_neutral mult_cancel_left mu_properties(1) omega_properties(3) zero_neq_one) + ultimately show False by simp + qed + + text \Application of the lemma for geometric sums.\ + + ultimately have "(1-\^(i-j))*(\l = 0..^(i-j))^l)) = (1-(\^(i-j))^n)" + using geo_sum[of "\ ^ (i - j)" n] by simp + moreover have "(\^(i-j))^n = 1" + by (metis (no_types) left_right_inverse_power mult.commute mult.right_neutral mu_properties(1) omega_properties(1) power_mult power_one) + + text \The sum for the current index is 0.\ + + ultimately have "(\l = 0..^(i-j))^l)) = 0" + by (metis \\ ^ (i - j) \ 1\ divisors_zero eq_iff_diff_eq_0) + thus "(\l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using 01 02 03 by simp + qed + + text \Case $i < j$. + We also rewrite the whole summation until the lemma for geometric sums is applicable. + From this, we conclude that the term is 0.\ + + have ilsj:"\ j. i < j \ j < n \ (\l = 0.. ^ (l * j) * \ ^ (i * l))) = 0" + proof- + fix j + assume ij_Assm: "i < j \ j < n" + hence 00:"\ (c::('a::prime_card) mod_ring) a b. (a*b)^i*(c * b^(j-i)) = c * a^i*b^j " + by (auto simp: field_simps simp flip: power_add) + have 01:" (\l = 0.. ^ (l * j) * \ ^ (i * l))) = + (\l = 0..^l)^(j-i))) " + apply(rule sum_eq) subgoal for l + using mu_properties(1) 00[of "\^l" "\^l" "numbers ! j "] algebra_simps(23) + by (smt (z3) "00" left_right_inverse_power mult.assoc mult.commute mult.right_neutral power_mult) + done + moreover have 02:"(\l = 0..^l)^(j-i))) = + numbers ! j *(\l = 0..^l)^(j-i))) " + by (simp add: mult_hom.hom_sum) + moreover have 03:"(\l = 0..^l)^(j-i))) = + (\l = 0..^(j-i))^l))) " + by(rule sum_eq) (metis mult.commute power_mult) + have "\^(j-i) \ 1" + proof + assume " \ ^ (j - i) = 1" + hence "ord p (to_int_mod_ring \) \ j-i" using ord_max[of "j-i" \] ij_Assm by simp + moreover have "ord p (to_int_mod_ring \) =p-1" + by (meson \\ ^ (j - i) = 1\ diff_is_0_eq diff_le_self ij_Assm leD le_trans omega_properties(3)) + ultimately show False + by (meson \\ ^ (j - i) = 1\ diff_is_0_eq diff_le_self ij_Assm leD le_trans omega_properties(3)) + qed + + text \Geometric sum.\ + + ultimately have "(1-\^(j-i))* (\l = 0..^(j-i))^l)) = (1-(\^(j-i))^n)" + using geo_sum[of "\ ^ (j-i)" n] by simp + moreover have "(\^(j-i))^n = 1" + by (metis (no_types) mult.commute omega_properties(1) power_mult power_one) + ultimately have "(\l = 0..^(j-i))^l)) = 0" + by (metis \\ ^ (j - i) \ 1\ eq_iff_diff_eq_0 no_zero_divisors) + thus "(\l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using 01 02 03 by simp + qed + + text \We compose the cases $j i$ to a complete summation over index $j$.\ + + have " (\j = 0..l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using jlsi by simp + moreover have " (\j = i..l = 0.. ^ (l * j) * \ ^ (i * l)) = numbers ! i * (of_int_mod_ring n)" using iisj by simp + moreover have " (\j = (i+1)..l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using ilsj by simp + ultimately have " (\j = 0..l = 0.. ^ (l * j) * \ ^ (i * l)) = + numbers ! i * (of_int_mod_ring n)" using i_assms sum_split + by (smt (z3) add.commute add.left_neutral int_ops(2) less_imp_of_nat_less of_nat_add of_nat_eq_iff of_nat_less_imp_less) + + text \Index-wise equality can be shown.\ + + thus "intt (NTT numbers) i = of_int_mod_ring (int n) * numbers ! i" using 1 2 3 + by (metis mult.commute) + qed + have 2: "\ i. i < n \ (map ((*) (of_int_mod_ring (int n))) numbers ) ! i = (of_int_mod_ring (int n)) * (numbers ! i)" + by (simp add: n_def) + + text \We relate index-wise equality to the function definition.\ + + show ?thesis + apply(rule nth_equalityI) + subgoal my_subgoal + unfolding INTT_def NTT_def + apply (simp add: n_def) + done + subgoal for i + using 0 1 2 n_def algebra_simps my_subgoal length_map + apply auto + done + done +qed + +text \Now we prove the converse to be true: +$\mathsf{NTT}(\mathsf{INTT}(\vec{x})) = n \cdot \vec{x}$. +The proof proceeds analogously with exchanged roles of $\omega$ and $\mu$. +\ + +theorem inv_ntt_correct: + assumes n_def: "length numbers = n" + shows "NTT (INTT numbers) = map (\ x. (of_int_mod_ring n) * x ) numbers" +proof- + have 0:"\ i. i < n \ (NTT (INTT numbers)) ! i = ntt (INTT numbers) i " using n_def length_NTT + unfolding INTT_def NTT_def intt_def by simp + have 1:"\ i. i < n \ntt (INTT numbers) i = (of_int_mod_ring n)*numbers ! i" + proof- + fix i + assume i_assms:"i < n" + hence 1:"ntt (INTT numbers) i = + (\l = 0..j = 0.. ^ (l * j)) * \ ^ (i * l))" + unfolding INTT_def ntt_def intt_def using n_def length_map nth_map by simp + hence 2:"\ = (\l = 0..j = 0.. ^ (l * j)) * \ ^ (i * l)))" using sum_in by simp + have 3:" \ =(\j = 0..l = 0.. ^ (l * j) * \ ^ (i * l))))" using sum_swap by fast + have iisj:"\ j. j = i \ (\l = 0.. ^ (l * j) * \ ^ (i * l))) = (numbers ! j)* (of_int_mod_ring n)" + proof- + fix j + assume "j=i" + hence "\ l. l < n \ (numbers ! j * \ ^ (l * j) * \ ^ (i * l))= (numbers ! j)" + by (simp add: left_right_inverse_power mult.commute mu_properties(1)) + moreover have "\ l. l < n \ numbers ! j * \ ^ (l * j) * \ ^ (i * l) = numbers ! j" + using calculation by blast + ultimately show "(\l = 0.. ^ (l * j) * \ ^ (i * l))) = (numbers ! j)* (of_int_mod_ring n)" + using n_def sum_const[of "numbers ! j" n] exp_rule[of \ \] mu_properties(1) + by (metis (no_types, lifting) atLeastLessThan_iff mult.commute sum.cong) + qed + have jlsi:"\ j. j < i \ (\l = 0.. ^ (l * j) * \ ^ (i * l))) = 0" + proof- + fix j + assume j_assms:"j < i" + hence 00:"\ (c::('a::prime_card) mod_ring) a b. c * a^j*b^i = (a*b)^j*(c * b^(i-j))" + using algebra_simps + by (smt (z3) le_less ordered_cancel_comm_monoid_diff_class.add_diff_inverse power_add) + have 01:" (\l = 0.. ^ (l * j) * \ ^ (i * l))) = + (\l = 0..^l)^(i-j))) " + apply(rule sum_eq) + using mu_properties(1) 00 algebra_simps(23) + by (smt (z3) mult.commute mult.left_neutral power_mult power_one) + moreover have 02: "\= numbers ! j *(\l = 0..^l)^(i-j))) " + using sum_in[of "\ l. numbers ! j * (\ ^ l) ^ (i - j)" " numbers ! j" n] + by (simp add: mult_hom.hom_sum) + moreover have 03:"(\l = 0..^l)^(i-j))) = + (\l = 0..^(i-j))^l)) " + by(rule sum_eq) (metis mult.commute power_mult) + have "\^(i-j) \ 1" + proof + assume "\ ^ (i - j) = 1" + hence "ord p (to_int_mod_ring \) \ i-j" + by (simp add: j_assms not_le ord_max) + moreover have "ord p (to_int_mod_ring \) = n" using omega_properties n_lst2 unfolding ord_def + by (meson \\ ^ (i - j) = 1\ diff_is_0_eq diff_le_self i_assms j_assms leD le_trans) + ultimately show False + by (metis i_assms leD less_imp_diff_less) + qed + ultimately have "(1-\^(i-j))*(\l = 0..^(i-j))^l)) = (1-(\^(i-j))^n)" + using geo_sum[of "\ ^ (i - j)" n] by simp + moreover have "(\^(i-j))^n = 1" + by (metis (no_types) mult.commute omega_properties(1) power_mult power_one) + ultimately have "(\l = 0..^(i-j))^l)) = 0" + by (metis \\ ^ (i - j) \ 1\ divisors_zero eq_iff_diff_eq_0) + thus "(\l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using 01 02 03 by simp + qed + have ilsj:"\ j. i < j \ j < n \ (\l = 0.. ^ (l * j) * \ ^ (i * l))) = 0" + proof- + fix j + assume ij_Assm: "i < j \ j < n" + hence 00:"\ (c::('a::prime_card) mod_ring) a b. (a*b)^i*(c * b^(j-i)) = c * a^i*b^j " + by (simp add: field_simps flip: power_add) + have 01:" (\l = 0.. ^ (l * j) * \ ^ (i * l))) = + (\l = 0..^l)^(j-i))) " + apply(rule sum_eq) subgoal for l + using mu_properties(1) 00[of "\^l" "\^l" "numbers ! j "] algebra_simps(23) + by (smt (z3) "00" left_right_inverse_power mult.assoc mult.commute mult.right_neutral power_mult) + done + moreover have 02:"(\l = 0..^l)^(j-i))) = + numbers ! j *(\l = 0..^l)^(j-i))) " + by (simp add: mult_hom.hom_sum) + moreover have 03:"(\l = 0..^l)^(j-i))) = + (\l = 0..^(j-i))^l))) " + by(rule sum_eq) (metis mult.commute power_mult) + have "\^(j-i) \ 1" + proof + assume "\ ^ (j - i) = 1" + hence "ord p (to_int_mod_ring \) \ j -i " + by (simp add: ij_Assm not_le ord_max) + moreover hence "ord p (to_int_mod_ring \) \ j-i" + by (metis \\ ^ (j - i) = 1\ diff_is_0_eq exp_rule ij_Assm leD mult.comm_neutral mult.commute mu_properties(1) ord_max) + moreover hence "j-i < n" using ij_Assm i_assms p_fact k_bound n_lst2 by linarith + moreover have "ord p (to_int_mod_ring \) = n" using omega_properties n_lst2 unfolding ord_def + by (metis (no_types) \\ ^ (j-i) = 1\ calculation(3) diff_is_0_eq ij_Assm leD left_right_inverse_power mult.comm_neutral mult_cancel_left mu_properties(1) omega_properties(3) zero_neq_one) + ultimately show False by simp + qed + ultimately have "(1-\^(j-i))* (\l = 0..^(j-i))^l)) = (1-(\^(j-i))^n)" + using geo_sum[of "\ ^ (j-i)" n] by simp + moreover have "(\^(j-i))^n = 1" + by (metis (no_types) left_right_inverse_power mult.commute mult.right_neutral mu_properties(1) omega_properties(1) power_mult power_one) + ultimately have "(\l = 0..^(j-i))^l)) = 0" + by (metis \\ ^ (j - i) \ 1\ eq_iff_diff_eq_0 no_zero_divisors) + thus "(\l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using 01 02 03 by simp + qed + have " (\j = 0..l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using jlsi by simp + moreover have " (\j = i..l = 0.. ^ (l * j) * \ ^ (i * l)) = numbers ! i * (of_int_mod_ring n)" using iisj by simp + moreover have " (\j = (i+1)..l = 0.. ^ (l * j) * \ ^ (i * l)) = 0" using ilsj by simp + ultimately have " (\j = 0..l = 0.. ^ (l * j) * \ ^ (i * l)) = + numbers ! i * (of_int_mod_ring n)" using i_assms sum_split + by (smt (z3) add.commute add.left_neutral int_ops(2) less_imp_of_nat_less of_nat_add of_nat_eq_iff of_nat_less_imp_less) + thus "ntt (INTT numbers) i = of_int_mod_ring (int n) * numbers ! i" using 1 2 3 + by (metis mult.commute) + qed + have 2: "\ i. i < n \ (map ((*) (of_int_mod_ring (int n))) numbers ) ! i = (of_int_mod_ring (int n)) * (numbers ! i)" + by (simp add: n_def) + show ?thesis + apply(rule nth_equalityI) + subgoal my_little_subgoal + unfolding INTT_def NTT_def + apply (simp add: n_def) + done + subgoal for i + using 0 1 2 n_def algebra_simps my_little_subgoal length_map + apply auto + done + done +qed + +end +end diff --git a/thys/Number_Theoretic_Transform/Preliminary_Lemmas.thy b/thys/Number_Theoretic_Transform/Preliminary_Lemmas.thy new file mode 100644 --- /dev/null +++ b/thys/Number_Theoretic_Transform/Preliminary_Lemmas.thy @@ -0,0 +1,453 @@ +(* +Title: Preliminary Lemmas for Number Theoretic Transform +Author: Thomas Ammer +*) + +theory Preliminary_Lemmas + imports Berlekamp_Zassenhaus.Finite_Field + "HOL-Number_Theory.Number_Theory" +begin + +section \Preliminary Lemmas\ + +subsection \A little bit of Modular Arithmetic\ + +text \An obvious lemma. Just for simplification.\ + +lemma two_powrs_div: + assumes "j < (i::nat) " + shows "((2^i) div ((2::nat)^(Suc j)))*2 = ((2^i) div (2^j))" +proof- + have "((2::nat)^i) div (2^(Suc j)) = 2^(i -1) div(2^ j)" using assms + by (smt (z3) One_nat_def add_le_cancel_left diff_Suc_Suc div_by_Suc_0 div_if less_nat_zero_code plus_1_eq_Suc power_diff_power_eq zero_neq_numeral) + thus ?thesis + by (metis Suc_diff_Suc Suc_leI assms less_imp_le_nat mult.commute power_Suc power_diff_power_eq zero_neq_numeral) +qed + +lemma two_powr_div: + assumes "j < (i::nat) " + shows "((2^i) div ((2::nat)^j)) = 2^(i-j)" + by (simp add: assms less_or_eq_imp_le power_diff) + + +text \ + The order of an element is the same whether we consider it as an integer or as a natural number. +\ +(* TODO: Move *) +lemma ord_int: "ord (int p) (int x) = ord p x" +proof (cases "coprime p x") + case False + thus ?thesis + by (auto simp: ord_def) +next + case True + have "(LEAST d. 0 < d \ [int x ^ d = 1] (mod int p)) = ord p x" + proof (intro Least_equality conjI) + show "[int x ^ ord p x = 1] (mod int p)" + using True by (metis cong_int_iff of_nat_1 of_nat_power ord_works) + show "ord p x \ y" if "0 < y \ [int x ^ y = 1] (mod int p)" for y + using that by (metis cong_int_iff int_ops(2) linorder_not_less of_nat_power ord_minimal) + qed (use True in auto) + thus ?thesis + by (auto simp: ord_def) +qed + +lemma not_residue_primroot_1: + assumes "n > 2" + shows "\residue_primroot n 1" + using assms totient_gt_1[of n] by (auto simp: residue_primroot_def) + +lemma residue_primroot_not_cong_1: + assumes "residue_primroot n g" "n > 2" + shows "[g \ 1] (mod n)" + using residue_primroot_cong not_residue_primroot_1 assms by metis + + +text \ +We want to show the existence of a generating element of $\mathbb{Z}_p$ where $p$ is prime. +\label{primroot1} \ + +text \Non-trivial order of an element $g$ modulo $p$ in a ring implies $g\neq1$. +Although this lemma applies to all rings, it's only intended to be used in connection with $nat$s or $int$s +\ + +lemma prime_not_2_order_not_1: + assumes "prime p" + "p > 2 " + "ord p g > 2" + shows "g \ 1" +proof + assume "g = 1" + hence "ord p g = 1" unfolding ord_def + by (simp add: Least_equality) + then show False using assms by auto +qed + +text \The same for modular arithmetic.\ + +lemma prime_not_2_order_not_1_mod: + assumes "prime p " + "p > 2 " + "ord p g > 2" + shows "[g \ 1] (mod p)" +proof + assume "[g = 1] (mod p)" + hence "ord p g = 1" unfolding ord_def + by(split if_split, metis assms(1) assms(2) assms(3) ord_cong prime_not_2_order_not_1) + then show False using assms by auto +qed + +text \ +Now we formulate our lemma about generating elements in residue classes: +There is an element $g \in \mathbb{Z}_p$ such that for any $x \in \mathbb{Z}_p$ +there is a natural $i$ such that $g^i \equiv x \; (\mod p)$.\ + +lemma generator_exists: + assumes "prime (p::nat)" "p > 2" + shows "\ g. [g \ 1] (mod p) \ (\ x. (0 x < p )\ (\ i. [g^i = x] (mod p)))" +proof- + obtain g where g_prim_root:"residue_primroot p g" + using assms prime_gt_1_nat prime_primitive_root_exists + by (metis One_nat_def) + have g_not_1: "[g \ 1] (mod p)" + using residue_primroot_not_cong_1 assms g_prim_root by blast + + have "\i. [g ^ i = x] (mod p)" if x_bounds: "x > 0" "x < p" for x + proof - + have 1:"coprime p x" + using assms prime_nat_iff'' x_bounds by blast + have 2:"ord p g = p-1" + by (metis assms(1) g_prim_root residue_primroot_def totient_prime) + hence bij: "bij_betw (\i. g ^ i mod p) {.. totatives p" + by (simp add: "1" coprime_commute in_totatives_iff order_le_less x_bounds) + have " {.. {}" + by (metis assms(1) lessThan_empty_iff prime_nat_iff'' totient_0_iff) + then obtain i where "g^i mod p = x mod p" + using bij_betw_inv[of "(\i. g ^ i mod p)" "{..General Lemmas in a Finite Field\ + +text \ +\label{primroot2} +We make certain assumptions: +From now on, we will calculate in a finite field which is the ring of integers modulo a prime $p$. +Let $n$ be the length of vectors to be transformed. +By Dirichlet's theorem on arithmetic progressions we can + assume that there is a natural number $k$ and a prime $p$ with $p = k\cdot n + 1$. +In order to avoid some special cases and even contradictions, +we additionally assume that $p \geq 3$ and $n \geq 2$. +\ + +text \\label{prelim}\ +locale preliminary = + fixes + a_type::"('a::prime_card) itself" + and p::nat + and n::nat + and k::nat + assumes + p_def: "p= CARD('a)" and p_lst3: "p > 2" and p_fact: "p = k*n +1" + and n_lst2: "n \ 2" +begin + +lemma exp_rule: "((c::('a) mod_ring) * d )^e= (c^e) * (d^e)" + by (simp add: power_mult_distrib) + +lemma "\ y. x \ 0 \ (x::(('a) mod_ring)) * y = 1" + by (metis dvd_field_iff unit_dvdE) + +lemma test: "prime p" + by (simp add: p_def prime_card) + +lemma k_bound: "k > 0" + using p_fact prime_nat_iff'' test by force + +text \We show some homomorphisms.\ + +lemma homomorphism_add: "(of_int_mod_ring x)+(of_int_mod_ring y) = + ((of_int_mod_ring (x+y)) ::(('a::prime_card) mod_ring))" + by (metis of_int_hom.hom_add of_int_of_int_mod_ring) + +lemma homomorphism_mul_on_ring: "(of_int_mod_ring x)*(of_int_mod_ring y) = + ((of_int_mod_ring (x*y)) ::(('a::prime_card) mod_ring))" + by (metis of_int_mult of_int_of_int_mod_ring) + +lemma exp_homo:"(of_int_mod_ring (x^i)) = ((of_int_mod_ring x)^i ::(('a::prime_card) mod_ring))" + by (induction i) (metis of_int_of_int_mod_ring of_int_power)+ + +lemma mod_homo: "((of_int_mod_ring x)::(('a::prime_card) mod_ring)) = of_int_mod_ring (x mod p)" + using p_def unfolding of_int_mod_ring_def by simp + +lemma int_exp_hom: "int x ^i = int (x^i)" + by simp + +lemma coprime_nat_int: "coprime (int p) (to_int_mod_ring pr) \ coprime p (nat(to_int_mod_ring pr))" + unfolding coprime_def to_int_mod_ring_def + by (smt (z3) Rep_mod_ring atLeastLessThan_iff dvd_trans int_dvd_int_iff int_nat_eq int_ops(2) prime_divisor_exists prime_nat_int_transfer primes_dvd_imp_eq test to_int_mod_ring.rep_eq to_int_mod_ring_def) + +lemma nat_int_mod:"[nat (to_int_mod_ring pr) ^ d = 1] (mod p) = + [ (to_int_mod_ring pr) ^ d = 1] (mod (int p)) " + unfolding to_int_mod_ring_def + by (metis Rep_mod_ring atLeastLessThan_iff cong_int_iff int_exp_hom int_nat_eq int_ops(2) to_int_mod_ring.rep_eq to_int_mod_ring_def) + +text \Order of $p$ doesn't change when interpreting it as an integer.\ + +lemma ord_lift: "ord (int p) (to_int_mod_ring pr) = ord p (nat (to_int_mod_ring pr))" +proof - + have "to_int_mod_ring pr = int (nat (to_int_mod_ring pr))" + by (metis Rep_mod_ring atLeastLessThan_iff int_nat_eq to_int_mod_ring.rep_eq) + thus ?thesis + using ord_int by metis +qed + +text \A primitive root has order $p-1$.\ + +lemma primroot_ord: "residue_primroot p g \ ord p g = p -1" + by (simp add: residue_primroot_def test totient_prime) + +text \If $x^l = 1$ in $\mathbb{Z}_p$, then $l$ is an upper bound for the order of $x$ in $\mathbb{Z}_ p$.\ + +lemma ord_max: + assumes "l \ 0" "(x :: (('a::prime_card) mod_ring))^l = 1" + shows " ord p (to_int_mod_ring x) \ l" +proof- + have "[(to_int_mod_ring x)^l = 1] (mod p)" + by (metis assms(2) cong_def exp_homo of_int_mod_ring.rep_eq of_int_mod_ring_to_int_mod_ring one_mod_card_int one_mod_ring.rep_eq p_def) + thus ?thesis unfolding ord_def using assms + by (smt (z3) Least_le less_imp_le_nat not_gr0) +qed + +subsection \Existence of $n$-th Roots of Unity in the Finite Field\ + +text \ +\label{primroot3} +We obtain an element in the finite field such that +its reinterpretation as a $nat$ will be a primitive root in the residue class modulo $p$. +The difference between residue classes, their representatives in the Integers and elements +of the finite field is notable. When conducting informal proofs, this distinction + is usually blurred, but Isabelle enforces the explicit conversion between those structures. +\ + +lemma primroot_ex: + obtains primroot::"('a::prime_card) mod_ring" where + "primroot^(p-1) = 1" + "primroot \ 1" + "residue_primroot p (nat (to_int_mod_ring primroot))" +proof- + obtain g where g_Def: "residue_primroot p g \ g \ 1" + using prime_nat_iff' prime_primitive_root_exists test + by (metis bigger_prime euler_theorem ord_1_right power_one_right prime_nat_iff'' residue_primroot.cases residue_primroot_cong) + hence "[g \ 1] (mod p)" using prime_not_2_order_not_1_mod[of p g] + by (metis One_nat_def p_lst3 less_numeral_extra(4) ord_eq_Suc_0_iff residue_primroot.cases totient_gt_1) + hence "[g^(p-1) = 1] (mod p)" using g_Def + by (metis coprime_commute euler_theorem residue_primroot_def test totient_prime) + moreover hence "int (g ^ (p - 1)) mod int p = (1::int)" + by (metis cong_def int_ops(2) mod_less of_nat_mod prime_gt_1_nat test) + moreover hence "of_int_mod_ring (int (g ^ (p - 1)) mod int p) = + ((of_int_mod_ring 1) ::(('a::prime_card) mod_ring))" by simp + ultimately have "(of_int_mod_ring (g^(p-1))) = (1 ::(('a::prime_card) mod_ring))" + using mod_homo[of "g^(p-1)"] by (metis exp_homo power_0) + hence "((of_int_mod_ring g)^(p-1) ::(('a::prime_card) mod_ring)) = 1" + using exp_homo[of "int g" "p-1"] by simp + moreover + have "((of_int_mod_ring g) ::(('a::prime_card) mod_ring)) \ 1" + proof + assume "((of_int_mod_ring g) ::(('a::prime_card) mod_ring)) = 1" + hence "[int g = 1] (mod p)" using p_def unfolding of_int_mod_ring_def + by (metis \of_int_mod_ring (int g) = 1\ cong_def of_int_mod_ring.rep_eq one_mod_card_int one_mod_ring.rep_eq) + hence "[g=1] (mod p)" + by (metis cong_int_iff int_ops(2)) + thus False + using \[g \ 1] (mod p)\ by auto + qed + thus ?thesis using that g_Def calculation + by (metis Euclidean_Division.pos_mod_bound Euclidean_Division.pos_mod_sign mod_homo nat_int of_nat_0_less_iff of_nat_mod p_def residue_primroot_mod to_int_mod_ring_of_int_mod_ring zero_less_card_finite) +qed + +text \From this, we obtain an $n$-th root of unity $\omega$ in the finite +field of characteristic $p$. + Note that in this step we will use the assumption $p = k \cdot n +1$ +from locale $preliminary$: The $k$-th power of a primitive +root $pr$ modulo $p$ will have the property $(pr^k)^n \equiv 1 \mod p$. +\ + +lemma omega_properties_ex: + obtains \ ::"(('a::prime_card) mod_ring)" + where "\^n = 1" + "\ \ 1" + "\ m. \^m = 1 \ m\0 \ m \ n" +proof- + obtain pr::"(('a::prime_card) mod_ring)" where a: "pr^(p-1) = 1 " and b: "pr \ 1" + and c: "residue_primroot p (nat( to_int_mod_ring pr))" + using primroot_ex by blast + moreover hence "(pr^k)^n =1" + by (simp add: p_fact power_mult) + moreover have "pr^k \ 1" + proof + assume " pr ^ k = 1" + hence "(to_int_mod_ring pr)^k mod p = 1" + by (metis exp_homo of_int_mod_ring.rep_eq of_int_mod_ring_to_int_mod_ring one_mod_ring.rep_eq p_def) + hence "ord p (to_int_mod_ring pr) \ k" + by (simp add: \pr ^ k = 1\ k_bound ord_max) + hence "ord p (nat (to_int_mod_ring pr)) \ k" + by (metis ord_lift) + also have "ord p (nat (to_int_mod_ring pr)) = p - 1" + using c primroot_ord[of "(nat (to_int_mod_ring pr))"] by blast + also have "\ = k * n" + using p_fact by simp + finally have "n \ 1" + using k_bound by simp + thus False + using n_lst2 by linarith + qed + moreover have "\ m. (pr^k)^m = 1 \ m\0 \ m \ n" + proof(rule ccontr) + assume "\ (\m. (pr ^ k) ^ m = 1 \ m\0 \ n \ m) " + then obtain m where "(pr^k)^m = 1 \ m\0 \ m < n" by force + hence "ord p (to_int_mod_ring pr) \ k * m" using ord_max[of "k*m" pr] + by (metis calculation(5) mult_is_0 power_mult) + moreover have "ord p (nat (to_int_mod_ring pr)) = p-1" using c primroot_ord ord_lift by simp + ultimately show False + by (metis \(pr ^ k) ^ m = 1 \ m \ 0 \ m < n\ add_diff_cancel_right' nat_0_less_mult_iff nat_mult_le_cancel_disj not_less ord_lift p_def p_fact prime_card prime_gt_1_nat zero_less_diff) + qed + ultimately show ?thesis + using that by simp +qed + +text \We define an $n$-th root of unity $\omega$ for $NTT$.\ +theorem omega_exists: "\ \ ::(('a::prime_card) mod_ring) . + \^n = 1 \ \ \ 1 \ (\ m. \^m = 1 \ m\0 \ m \ n)" + using omega_properties_ex by metis + +definition "(omega::(('a::prime_card) mod_ring)) = + (SOME \ . (\^n = 1 \ \ \ 1\ (\ m. \^m = 1 \ m\0 \ m \ n)))" + +lemma omega_properties: "omega^n = 1" "omega \ 1" + "(\ m. omega^m = 1 \ m\0 \ m \ n)" + unfolding omega_def using omega_exists + by (smt (verit, best) verit_sko_ex')+ + +text \We define the multiplicative inverse $\mu$ of $\omega$.\ + +definition "mu = omega ^ (n - 1)" + +lemma mu_properties: "mu * omega = 1" "mu \ 1" +proof - + have "omega ^ (n - 1) * omega = omega ^ Suc (n - 1)" + by simp + also have "Suc (n - 1) = n" + using n_lst2 by simp + also have "omega ^ n = 1" + using omega_properties(1) by auto + finally show "mu * omega = 1" + by (simp add: mu_def) +next + show "mu \ 1" + using omega_properties n_lst2 by (auto simp: mu_def) +qed + +subsection \Some Lemmas on Sums\ +text \\label{sums}\ + +text \The following lemmas concern sums over a finite field. + Most of the propositions are intuitive.\ + +lemma sum_in: "(\i=0..<(x::nat). f i * (y ::('a mod_ring))) = (\i=0.. i. i < x \ f i = g i) + \ (\i=0..<(x::nat). f i) = (\i=0..i=0..<(x::nat). (f i)::('a mod_ring)) - (\i=0..i=0..i=0..<(x::nat). \j=0..<(y::nat). f i j) = + (\j=0..<(y::nat). \i=0..<(x::nat). f i j ) " + using Groups_Big.comm_monoid_add_class.sum.swap by fast + +lemma sum_const: "(\i=0..<(x::nat). (c::('a::prime_card) mod_ring)) = (of_int_mod_ring x) * c" + by(induction x, simp add: algebra_simps, simp add: algebra_simps) + (metis distrib_left mult.right_neutral of_int_of_int_mod_ring of_int_of_nat_eq of_nat_Suc) + +lemma sum_split: "(r1::nat) < r2 \ (\l = 0..l = r1..l = 0..l = (a::nat)..< b. f(l+c)) = (\l = (a+c)..< (b+c). f l )" + by(induction a arbitrary: b c) (metis sum.shift_bounds_nat_ivl)+ + +text \One may sum over even and odd indices independently. +The lemma statement was taken from a formalization of FFT~\parencite{FFT-AFP}. +We give an alternative proof adapted to the finite field $\mathbb{Z}_p$. +\ + +lemma sum_splice: + "(\i::nat = 0..<2*nn. f i) = (\i = 0..i = 0..i::nat = 0..<2*(n+1). f i) = (\i::nat = 0..<(2*n). f i) + f(2*n+1) + f (2*n)" + by( simp add: algebra_simps) + also have "\ = (\i::nat = 0..i::nat = 0.. = (\i::nat = 0..<(Suc n). f (2*i)) + (\i::nat = 0..<(Suc n). f (2*i+1))" + by( simp add: algebra_simps) + finally show ?case by simp +qed simp + +lemma sum_even_odd_split: "even (a::nat) \ (\j=0..<(a div 2). f (2*j))+ (\j=0..<(a div 2). f (2*j+1)) = (\j=0..j=(0::nat)..j=0..j=(0::nat)..<2*i. f j )" +by (metis sum_splice) + +lemma sum_neg_in: "- (\j = 0..j = 0..Geometric Sums\ + +text \\label{geosum}\ + +text \This lemma will be important for proving properties on $\mathsf{NTT}$. At first, an informal proof sketch: +\begin{align*} +(1-x) \cdot \sum \limits _{l = 0} ^ {r-1} x^l +&= \sum \limits _{l = 0} ^ {r-1} x^l - x \cdot \sum \limits _{l = 0} ^{r-1} x^l \\ +&= \sum \limits _{l = 0} ^ {r-1} x^l - \sum \limits _{l = 1} ^{r} x^l \\ +& = 1 - x^r +\end{align*} + +The same lemma for integers can be found in~\parencite{Dirichlet_Series-AFP}. + Our version is adapted to finite fields. +\ + +lemma geo_sum: + assumes "x \ 1" + shows "(1-x)*(\l = 0..l = 0..l = 0.. l. x^l" x r] + by(simp add: algebra_simps) + have 1:"(\l = 0..l = 0..l = 0..l = 0..