diff --git a/thys/CRYSTALS-Kyber/Crypto_Scheme_NTT.thy b/thys/CRYSTALS-Kyber/Crypto_Scheme_NTT.thy --- a/thys/CRYSTALS-Kyber/Crypto_Scheme_NTT.thy +++ b/thys/CRYSTALS-Kyber/Crypto_Scheme_NTT.thy @@ -1,95 +1,99 @@ theory Crypto_Scheme_NTT imports Crypto_Scheme NTT_Scheme begin section \Kyber Algorithm using NTT for Fast Multiplication\ +hide_type Matrix.vec + context kyber_ntt begin + definition mult_ntt:: "'a qr \ 'a qr \ 'a qr" (infixl "*\<^bsub>ntt\<^esub>" 70) where "mult_ntt f g = inv_ntt_poly (ntt_poly f \ ntt_poly g)" lemma mult_ntt: "f*g = f *\<^bsub>ntt\<^esub> g" unfolding mult_ntt_def using convolution_thm_ntt_poly by auto -definition scalar_prod_ntt:: "('a qr, 'k) vec \ ('a qr, 'k) vec \ 'a qr" (infixl "\\<^bsub>ntt\<^esub>" 70) where +definition scalar_prod_ntt:: + "('a qr, 'k) vec \ ('a qr, 'k) vec \ 'a qr" (infixl "\\<^bsub>ntt\<^esub>" 70) where "scalar_prod_ntt v w = (\i\(UNIV::'k set). (vec_nth v i) *\<^bsub>ntt\<^esub> (vec_nth w i))" lemma scalar_prod_ntt: "scalar_product v w = scalar_prod_ntt v w" unfolding scalar_product_def scalar_prod_ntt_def using mult_ntt by auto definition mat_vec_mult_ntt:: "(('a qr, 'k) vec, 'k) vec \ ('a qr, 'k) vec \ ('a qr, 'k) vec" (infixl "\\<^bsub>ntt\<^esub>" 70) where "mat_vec_mult_ntt A v = vec_lambda (\i. (\j\UNIV. (vec_nth (vec_nth A i) j) *\<^bsub>ntt\<^esub> (vec_nth v j)))" lemma mat_vec_mult_ntt: "A *v v = mat_vec_mult_ntt A v" unfolding matrix_vector_mult_def mat_vec_mult_ntt_def using mult_ntt by auto text \Refined algorithm using NTT for multiplications\ definition key_gen_ntt :: "nat \ (('a qr, 'k) vec, 'k) vec \ ('a qr, 'k) vec \ ('a qr, 'k) vec \ ('a qr, 'k) vec" where "key_gen_ntt dt A s e = compress_vec dt (A \\<^bsub>ntt\<^esub> s + e)" lemma key_gen_ntt: "key_gen_ntt dt A s e = key_gen dt A s e" unfolding key_gen_ntt_def key_gen_def mat_vec_mult_ntt by auto definition encrypt_ntt :: "('a qr, 'k) vec \ (('a qr, 'k) vec, 'k) vec \ ('a qr, 'k) vec \ ('a qr, 'k) vec \ ('a qr) \ nat \ nat \ nat \ 'a qr \ (('a qr, 'k) vec) * ('a qr)" where "encrypt_ntt t A r e1 e2 dt du dv m = (compress_vec du ((transpose A) \\<^bsub>ntt\<^esub> r + e1), compress_poly dv ((decompress_vec dt t) \\<^bsub>ntt\<^esub> r + e2 + to_module (round((real_of_int q)/2)) *\<^bsub>ntt\<^esub> m)) " lemma encrypt_ntt: "encrypt_ntt t A r e1 e2 dt du dv m = encrypt t A r e1 e2 dt du dv m" unfolding encrypt_ntt_def encrypt_def mat_vec_mult_ntt scalar_prod_ntt mult_ntt by auto definition decrypt_ntt :: "('a qr, 'k) vec \ ('a qr) \ ('a qr, 'k) vec \ nat \ nat \ 'a qr" where "decrypt_ntt u v s du dv = compress_poly 1 ((decompress_poly dv v) - s \\<^bsub>ntt\<^esub> (decompress_vec du u))" lemma decrypt_ntt: "decrypt_ntt u v s du dv = decrypt u v s du dv" unfolding decrypt_ntt_def decrypt_def scalar_prod_ntt by auto text \$(1-\delta)$-correctness for the refined algorithm\ lemma kyber_correct_ntt: fixes A s r e e1 e2 dt du dv ct cu cv t u v assumes t_def: "t = key_gen_ntt dt A s e" and u_v_def: "(u,v) = encrypt_ntt t A r e1 e2 dt du dv m" and ct_def: "ct = compress_error_vec dt (A \\<^bsub>ntt\<^esub> s + e)" and cu_def: "cu = compress_error_vec du ((transpose A) \\<^bsub>ntt\<^esub> r + e1)" and cv_def: "cv = compress_error_poly dv ((decompress_vec dt t) \\<^bsub>ntt\<^esub> r + e2 + to_module (round((real_of_int q)/2)) *\<^bsub>ntt\<^esub> m)" and delta: "abs_infty_poly (e \\<^bsub>ntt\<^esub> r + e2 + cv - s \\<^bsub>ntt\<^esub> e1 + ct \\<^bsub>ntt\<^esub> r - s \\<^bsub>ntt\<^esub> cu) < round (real_of_int q / 4)" and m01: "set ((coeffs \ of_qr) m) \ {0,1}" shows "decrypt_ntt u v s du dv = m" using assms unfolding key_gen_ntt encrypt_ntt decrypt_ntt mat_vec_mult_ntt[symmetric] scalar_prod_ntt[symmetric] mult_ntt[symmetric] using kyber_correct by auto end end \ No newline at end of file