diff --git a/thys/Certification_Monads/Error_Monad.thy b/thys/Certification_Monads/Error_Monad.thy --- a/thys/Certification_Monads/Error_Monad.thy +++ b/thys/Certification_Monads/Error_Monad.thy @@ -1,323 +1,330 @@ (* Title: Error_Monad Author: Christian Sternagel Author: René Thiemann *) section \The Sum Type as Error Monad\ theory Error_Monad imports "HOL-Library.Monad_Syntax" Error_Syntax begin text \Make monad syntax (including do-notation) available for the sum type.\ definition bind :: "'e + 'a \ ('a \ 'e + 'b) \ 'e + 'b" where "bind m f = (case m of Inr x \ f x | Inl e \ Inl e)" adhoc_overloading Monad_Syntax.bind bind abbreviation (input) "return \ Inr" abbreviation (input) "error \ Inl" abbreviation (input) "run \ projr" subsection \Monad Laws\ lemma return_bind [simp]: "(return x \ f) = f x" by (simp add: bind_def) lemma bind_return [simp]: "(m \ return) = m" by (cases m) (simp_all add: bind_def) lemma error_bind [simp]: "(error e \ f) = error e" by (simp add: bind_def) lemma bind_assoc [simp]: fixes m :: "'a + 'b" shows "((m \ f) \ g) = (m \ (\x. f x \ g))" by (cases m) (simp_all add: bind_def) lemma bind_cong [fundef_cong]: fixes m1 m2 :: "'e + 'a" and f1 f2 :: "'a \ 'e + 'b" assumes "m1 = m2" and "\y. m2 = Inr y \ f1 y = f2 y" shows "(m1 \ f1) = (m2 \ f2)" using assms by (cases "m1") (auto simp: bind_def) definition catch_error :: "'e + 'a \ ('e \ 'f + 'a) \ 'f + 'a" where catch_def: "catch_error m f = (case m of Inl e \ f e | Inr x \ Inr x)" adhoc_overloading Error_Syntax.catch catch_error lemma catch_splits: "P (try m catch f) \ (\e. m = Inl e \ P (f e)) \ (\x. m = Inr x \ P (Inr x))" "P (try m catch f) \ (\ ((\e. m = Inl e \ \ P (f e)) \ (\x. m = Inr x \ \ P (Inr x))))" by (case_tac [!] m) (simp_all add: catch_def) abbreviation update_error :: "'e + 'a \ ('e \ 'f) \ 'f + 'a" where "update_error m f \ try m catch (\x. error (f x))" adhoc_overloading Error_Syntax.update_error update_error lemma catch_return [simp]: "(try return x catch f) = return x" by (simp add: catch_def) lemma catch_error [simp]: "(try error e catch f) = f e" by (simp add: catch_def) lemma update_error_return [simp]: "(m <+? c = return x) \ (m = return x)" by (cases m) simp_all definition "isOK m \ (case m of Inl e \ False | Inr x \ True)" lemma isOK_E [elim]: assumes "isOK m" obtains x where "m = return x" using assms by (cases m) (simp_all add: isOK_def) lemma isOK_I [simp, intro]: "m = return x \ isOK m" by (cases m) (simp_all add: isOK_def) lemma isOK_iff: "isOK m \ (\x. m = return x)" by blast lemma isOK_error [simp]: "isOK (error x) = False" by blast lemma isOK_bind [simp]: "isOK (m \ f) \ isOK m \ isOK (f (run m))" by (cases m) simp_all lemma isOK_update_error [simp]: "isOK (m <+? f) \ isOK m" by (cases m) simp_all lemma isOK_case_prod [simp]: "isOK (case lr of (l, r) \ P l r) = (case lr of (l, r) \ isOK (P l r))" by (rule prod.case_distrib) lemma isOK_case_option [simp]: "isOK (case x of None \ P | Some v \ Q v) = (case x of None \ isOK P | Some v \ isOK (Q v))" by (cases x) (auto) lemma isOK_Let [simp]: "isOK (Let s f) = isOK (f s)" by (simp add: Let_def) lemma run_bind [simp]: "isOK m \ run (m \ f) = run (f (run m))" by auto lemma run_catch [simp]: "isOK m \ run (try m catch f) = run m" by auto fun foldM :: "('a \ 'b \ 'e + 'a) \ 'a \ 'b list \ 'e + 'a" where "foldM f d [] = return d" | "foldM f d (x # xs) = do { y \ f d x; foldM f y xs }" fun forallM_index_aux :: "('a \ nat \ 'e + unit) \ nat \ 'a list \ (('a \ nat) \ 'e) + unit" where "forallM_index_aux P i [] = return ()" | "forallM_index_aux P i (x # xs) = do { P x i <+? Pair (x, i); forallM_index_aux P (Suc i) xs }" lemma isOK_forallM_index_aux [simp]: "isOK (forallM_index_aux P n xs) = (\i < length xs. isOK (P (xs ! i) (i + n)))" proof (induct xs arbitrary: n) case (Cons x xs) have "(\i < length (x # xs). isOK (P ((x # xs) ! i) (i + n))) \ (isOK (P x n) \ (\i < length xs. isOK (P (xs ! i) (i + Suc n))))" by (auto, case_tac i) (simp_all) then show ?case unfolding Cons [of "Suc n", symmetric] by simp qed auto definition forallM_index :: "('a \ nat \ 'e + unit) \ 'a list \ (('a \ nat) \ 'e) + unit" where "forallM_index P xs = forallM_index_aux P 0 xs" lemma isOK_forallM_index [simp]: "isOK (forallM_index P xs) \ (\i < length xs. isOK (P (xs ! i) i))" unfolding forallM_index_def isOK_forallM_index_aux by simp lemma forallM_index [fundef_cong]: fixes c :: "'a \ nat \ 'e + unit" assumes "\x i. x \ set xs \ c x i = d x i" shows "forallM_index c xs = forallM_index d xs" proof - { fix n have "forallM_index_aux c n xs = forallM_index_aux d n xs" using assms by (induct xs arbitrary: n) simp_all } then show ?thesis by (simp add: forallM_index_def) qed hide_const forallM_index_aux text \ Check whether @{term f} succeeds for all elements of a given list. In case it doesn't, return the first offending element together with the produced error. \ fun forallM :: "('a \ 'e + unit) \ 'a list \ ('a * 'e) + unit" where "forallM f [] = return ()" | "forallM f (x # xs) = f x <+? Pair x \ forallM f xs" lemma forallM_fundef_cong [fundef_cong]: assumes "xs = ys" "\x. x \ set ys \ f x = g x" shows "forallM f xs = forallM g ys" unfolding assms(1) using assms(2) proof (induct ys) case (Cons x xs) thus ?case by (cases "g x", auto) qed auto lemma isOK_forallM [simp]: "isOK (forallM f xs) \ (\x \ set xs. isOK (f x))" by (induct xs) (simp_all) text \ Check whether @{term f} succeeds for at least one element of a given list. In case it doesn't, return the list of produced errors. \ fun existsM :: "('a \ 'e + unit) \ 'a list \ 'e list + unit" where "existsM f [] = error []" | "existsM f (x # xs) = (try f x catch (\e. existsM f xs <+? Cons e))" +lemma existsM_cong [fundef_cong]: + assumes "xs = ys" + and "\x. x \ set ys \ f x = g x" + shows "existsM f xs = existsM g ys" + using assms + by (induct ys arbitrary:xs) (auto split:catch_splits) + lemma isOK_existsM [simp]: "isOK (existsM f xs) \ (\x\set xs. isOK (f x))" proof (induct xs) case (Cons x xs) show ?case proof (cases "f x") case (Inl e) with Cons show ?thesis by simp qed (auto simp add: catch_def) qed simp lemma is_OK_if_return [simp]: "isOK (if b then return x else m) \ b \ isOK m" "isOK (if b then m else return x) \ \ b \ isOK m" by simp_all lemma isOK_if_error [simp]: "isOK (if b then error e else m) \ \ b \ isOK m" "isOK (if b then m else error e) \ b \ isOK m" by simp_all lemma isOK_if: "isOK (if b then x else y) \ b \ isOK x \ \ b \ isOK y" by simp fun sequence :: "('e + 'a) list \ 'e + 'a list" where "sequence [] = Inr []" | "sequence (m # ms) = do { x \ m; xs \ sequence ms; return (x # xs) }" subsection \Monadic Map for Error Monad\ fun mapM :: "('a \ 'e + 'b) \ 'a list \ 'e + 'b list" where "mapM f [] = return []" | "mapM f (x#xs) = do { y \ f x; ys \ mapM f xs; Inr (y # ys) }" lemma mapM_error: "(\e. mapM f xs = error e) \ (\x\set xs. \e. f x = error e)" proof (induct xs) case (Cons x xs) then show ?case by (cases "f x", simp_all, cases "mapM f xs", simp_all) qed simp lemma mapM_return: assumes "mapM f xs = return ys" shows "ys = map (run \ f) xs \ (\x\set xs. \e. f x \ error e)" using assms proof (induct xs arbitrary: ys) case (Cons x xs ys) then show ?case by (cases "f x", simp, cases "mapM f xs", simp_all) qed simp lemma mapM_return_idx: assumes *: "mapM f xs = Inr ys" and "i < length xs" shows "\y. f (xs ! i) = Inr y \ ys ! i = y" proof - note ** = mapM_return [OF *, unfolded set_conv_nth] with assms have "\e. f (xs ! i) \ Inl e" by auto then obtain y where "f (xs ! i) = Inr y" by (cases "f (xs ! i)") auto then have "f (xs ! i) = Inr y \ ys ! i = y" unfolding ** [THEN conjunct1] using assms by auto then show ?thesis .. qed lemma mapM_cong [fundef_cong]: assumes "xs = ys" and "\x. x \ set ys \ f x = g x" shows "mapM f xs = mapM g ys" unfolding assms(1) using assms(2) by (induct ys) auto lemma bindE [elim]: assumes "(p \ f) = return x" obtains y where "p = return y" and "f y = return x" using assms by (cases p) simp_all lemma then_return_eq [simp]: "(p \ q) = return f \ isOK p \ q = return f" by (cases p) simp_all fun choice :: "('e + 'a) list \ 'e list + 'a" where "choice [] = error []" | "choice (x # xs) = (try x catch (\e. choice xs <+? Cons e))" declare choice.simps [simp del] lemma isOK_mapM: assumes "isOK (mapM f xs)" shows "(\x. x \ set xs \ isOK (f x)) \ run (mapM f xs) = map (\x. run (f x)) xs" using assms mapM_return[of f xs] by (force simp: isOK_def split: sum.splits)+ fun firstM where "firstM f [] = error []" | "firstM f (x # xs) = (try f x \ return x catch (\e. firstM f xs <+? Cons e))" lemma firstM: "isOK (firstM f xs) \ (\x\set xs. isOK (f x))" by (induct xs) (auto simp: catch_def split: sum.splits) lemma firstM_return: assumes "firstM f xs = return y" shows "isOK (f y) \ y \ set xs" using assms by (induct xs) (auto simp: catch_def split: sum.splits) end