diff --git a/src/Pure/System/scala.scala b/src/Pure/System/scala.scala --- a/src/Pure/System/scala.scala +++ b/src/Pure/System/scala.scala @@ -1,252 +1,262 @@ /* Title: Pure/System/scala.scala Author: Makarius Support for Scala at runtime. */ package isabelle import java.io.{File => JFile, StringWriter, PrintWriter} import scala.tools.nsc.{GenericRunnerSettings, ConsoleWriter, NewLinePrintWriter} import scala.tools.nsc.interpreter.{IMain, Results} import scala.tools.nsc.interpreter.shell.ReplReporterImpl object Scala { /** registered functions **/ abstract class Fun(val name: String, val thread: Boolean = false) { override def toString: String = name def multi: Boolean = true def position: Properties.T = here.position def here: Scala_Project.Here def invoke(args: List[Bytes]): List[Bytes] } abstract class Fun_Strings(name: String, thread: Boolean = false) extends Fun(name, thread = thread) { override def invoke(args: List[Bytes]): List[Bytes] = apply(args.map(_.text)).map(Bytes.apply) def apply(args: List[String]): List[String] } abstract class Fun_String(name: String, thread: Boolean = false) extends Fun_Strings(name, thread = thread) { override def multi: Boolean = false override def apply(args: List[String]): List[String] = List(apply(Library.the_single(args))) def apply(arg: String): String } class Functions(val functions: Fun*) extends Isabelle_System.Service lazy val functions: List[Fun] = Isabelle_System.make_services(classOf[Functions]).flatMap(_.functions) /** demo functions **/ object Echo extends Fun_String("echo") { val here = Scala_Project.here def apply(arg: String): String = arg } object Sleep extends Fun_String("sleep") { val here = Scala_Project.here def apply(seconds: String): String = { val t = seconds match { case Value.Double(s) => Time.seconds(s) case _ => error("Malformed argument: " + quote(seconds)) } val t0 = Time.now() t.sleep() val t1 = Time.now() (t1 - t0).toString } } /** compiler **/ def class_path(): List[String] = for { prop <- List("isabelle.scala.classpath", "java.class.path") elems = System.getProperty(prop, "") if elems.nonEmpty elem <- space_explode(JFile.pathSeparatorChar, elems) if elem.nonEmpty } yield elem object Compiler { def context( error: String => Unit = Exn.error, jar_dirs: List[JFile] = Nil ): Context = { def find_jars(dir: JFile): List[String] = File.find_files(dir, file => file.getName.endsWith(".jar")). map(File.absolute_name) val settings = new GenericRunnerSettings(error) settings.classpath.value = (class_path() ::: jar_dirs.flatMap(find_jars)).mkString(JFile.pathSeparator) new Context(settings) } def default_print_writer: PrintWriter = new NewLinePrintWriter(new ConsoleWriter, true) class Context private [Compiler](val settings: GenericRunnerSettings) { override def toString: String = settings.toString def interpreter( print_writer: PrintWriter = default_print_writer, class_loader: ClassLoader = null ): IMain = { new IMain(settings, new ReplReporterImpl(settings, print_writer)) { override def parentClassLoader: ClassLoader = if (class_loader == null) super.parentClassLoader else class_loader } } - def toplevel(source: String): List[String] = { + def toplevel(interpret: Boolean, source: String): List[String] = { val out = new StringWriter val interp = interpreter(new PrintWriter(out)) val marker = '\u000b' - val ok = interp.withLabel(marker.toString) { (new interp.ReadEvalPrint).compile(source) } + val ok = + interp.withLabel(marker.toString) { + if (interpret) interp.interpret(source) == Results.Success + else (new interp.ReadEvalPrint).compile(source) + } out.close() val Error = """(?s)^\S* error: (.*)$""".r val errors = space_explode(marker, Library.strip_ansi_color(out.toString)). collect({ case Error(msg) => "Scala error: " + Library.trim_line(msg) }) if (!ok && errors.isEmpty) List("Error") else errors } } } object Toplevel extends Fun_String("scala_toplevel") { val here = Scala_Project.here - def apply(source: String): String = { + def apply(arg: String): String = { + val (interpret, source) = + YXML.parse_body(arg) match { + case Nil => (false, "") + case List(XML.Text(source)) => (false, source) + case body => import XML.Decode._; pair(bool, string)(body) + } val errors = - try { Compiler.context().toplevel(source) } + try { Compiler.context().toplevel(interpret, source) } catch { case ERROR(msg) => List(msg) } locally { import XML.Encode._; YXML.string_of_body(list(string)(errors)) } } } /** invoke Scala functions from ML **/ /* invoke function */ object Tag extends Enumeration { val NULL, OK, ERROR, FAIL, INTERRUPT = Value } def function_thread(name: String): Boolean = functions.find(fun => fun.name == name) match { case Some(fun) => fun.thread case None => false } def function_body(name: String, args: List[Bytes]): (Tag.Value, List[Bytes]) = functions.find(fun => fun.name == name) match { case Some(fun) => Exn.capture { fun.invoke(args) } match { case Exn.Res(null) => (Tag.NULL, Nil) case Exn.Res(res) => (Tag.OK, res) case Exn.Exn(Exn.Interrupt()) => (Tag.INTERRUPT, Nil) case Exn.Exn(e) => (Tag.ERROR, List(Bytes(Exn.message(e)))) } case None => (Tag.FAIL, List(Bytes("Unknown Isabelle/Scala function: " + quote(name)))) } /* protocol handler */ class Handler extends Session.Protocol_Handler { private var session: Session = null private var futures = Map.empty[String, Future[Unit]] override def init(session: Session): Unit = synchronized { this.session = session } override def exit(): Unit = synchronized { for ((id, future) <- futures) cancel(id, future) futures = Map.empty } private def result(id: String, tag: Scala.Tag.Value, res: List[Bytes]): Unit = synchronized { if (futures.isDefinedAt(id)) { session.protocol_command_raw("Scala.result", Bytes(id) :: Bytes(tag.id.toString) :: res) futures -= id } } private def cancel(id: String, future: Future[Unit]): Unit = { future.cancel() result(id, Scala.Tag.INTERRUPT, Nil) } private def invoke_scala(msg: Prover.Protocol_Output): Boolean = synchronized { msg.properties match { case Markup.Invoke_Scala(name, id) => def body: Unit = { val (tag, res) = Scala.function_body(name, msg.chunks) result(id, tag, res) } val future = if (Scala.function_thread(name)) { Future.thread(name = Isabelle_Thread.make_name(base = "invoke_scala"))(body) } else Future.fork(body) futures += (id -> future) true case _ => false } } private def cancel_scala(msg: Prover.Protocol_Output): Boolean = synchronized { msg.properties match { case Markup.Cancel_Scala(id) => futures.get(id) match { case Some(future) => cancel(id, future) case None => } true case _ => false } } override val functions = List( Markup.Invoke_Scala.name -> invoke_scala, Markup.Cancel_Scala.name -> cancel_scala) } } class Scala_Functions extends Scala.Functions( Scala.Echo, Scala.Sleep, Scala.Toplevel, Bytes.Decode_Base64, Bytes.Encode_Base64, Doc.Doc_Names, Bibtex.Check_Database, Isabelle_System.Make_Directory, Isabelle_System.Copy_Dir, Isabelle_System.Copy_File, Isabelle_System.Copy_File_Base, Isabelle_System.Rm_Tree, Isabelle_System.Download, Isabelle_System.Isabelle_Id, Isabelle_Tool.Isabelle_Tools, isabelle.atp.SystemOnTPTP.List_Systems, isabelle.atp.SystemOnTPTP.Run_System) diff --git a/src/Pure/System/scala_compiler.ML b/src/Pure/System/scala_compiler.ML --- a/src/Pure/System/scala_compiler.ML +++ b/src/Pure/System/scala_compiler.ML @@ -1,95 +1,98 @@ (* Title: Pure/System/scala_compiler.ML Author: Makarius Scala compiler operations. *) signature SCALA_COMPILER = sig - val toplevel: string -> unit + val toplevel: bool -> string -> unit val static_check: string * Position.T -> unit end; structure Scala_Compiler: SCALA_COMPILER = struct (* check declaration *) -fun toplevel source = +fun toplevel interpret source = let val errors = - \<^scala>\scala_toplevel\ source + (interpret, source) + |> let open XML.Encode in pair bool string end + |> YXML.string_of_body + |> \<^scala>\scala_toplevel\ |> YXML.parse_body |> let open XML.Decode in list string end in if null errors then () else error (cat_lines errors) end; fun static_check (source, pos) = - toplevel ("package test\nclass __Dummy__ { __dummy__ => " ^ source ^ " }") + toplevel false ("package test\nclass __Dummy__ { __dummy__ => " ^ source ^ " }") handle ERROR msg => error (msg ^ Position.here pos); (* antiquotations *) local fun make_list bg en = space_implode "," #> enclose bg en; fun print_args [] = "" | print_args xs = make_list "(" ")" xs; fun print_types [] = "" | print_types Ts = make_list "[" "]" Ts; fun print_class (c, Ts) = c ^ print_types Ts; val types = Scan.optional (Parse.$$$ "[" |-- Parse.list1 Parse.name --| Parse.$$$ "]") []; val class = Scan.option (Parse.$$$ "(" |-- Parse.!!! (Parse.$$$ "in" |-- Parse.name -- types --| Parse.$$$ ")")); val arguments = (Parse.nat >> (fn n => replicate n "_") || Parse.list (Parse.underscore || Parse.name >> (fn T => "_ : " ^ T))) >> print_args; val args = Scan.optional (Parse.$$$ "(" |-- arguments --| Parse.$$$ ")") " _"; fun scala_name name = Latex.macro "isatt" (Latex.string (Latex.output_ascii_breakable "." name)); in val _ = Theory.setup (Document_Output.antiquotation_verbatim_embedded \<^binding>\scala\ (Scan.lift Parse.embedded_position) (fn _ => fn (s, pos) => (static_check (s, pos); s)) #> Document_Output.antiquotation_raw_embedded \<^binding>\scala_type\ (Scan.lift (Parse.embedded_position -- (types >> print_types))) (fn _ => fn ((t, pos), type_args) => (static_check ("type _Test_" ^ type_args ^ " = " ^ t ^ type_args, pos); scala_name (t ^ type_args))) #> Document_Output.antiquotation_raw_embedded \<^binding>\scala_object\ (Scan.lift Parse.embedded_position) (fn _ => fn (x, pos) => (static_check ("val _test_ = " ^ x, pos); scala_name x)) #> Document_Output.antiquotation_raw_embedded \<^binding>\scala_method\ (Scan.lift (class -- Parse.embedded_position -- types -- args)) (fn _ => fn (((class_context, (method, pos)), method_types), method_args) => let val class_types = (case class_context of SOME (_, Ts) => Ts | NONE => []); val def = "def _test_" ^ print_types (merge (op =) (method_types, class_types)); val def_context = (case class_context of NONE => def ^ " = " | SOME c => def ^ "(_this_ : " ^ print_class c ^ ") = _this_."); val source = def_context ^ method ^ method_args; val _ = static_check (source, pos); val text = (case class_context of NONE => method | SOME c => print_class c ^ "." ^ method); in scala_name text end)); end; end;