diff --git a/src/Pure/System/scala.scala b/src/Pure/System/scala.scala --- a/src/Pure/System/scala.scala +++ b/src/Pure/System/scala.scala @@ -1,225 +1,234 @@ /* Title: Pure/System/scala.scala Author: Makarius Support for Scala at runtime. */ package isabelle import java.io.{File => JFile, StringWriter, PrintWriter} import scala.tools.nsc.{GenericRunnerSettings, ConsoleWriter, NewLinePrintWriter} -import scala.tools.nsc.interpreter.IMain +import scala.tools.nsc.interpreter.{IMain, Results} object Scala { /** registered functions **/ abstract class Fun(val name: String) extends Function[String, String] { override def toString: String = name def apply(arg: String): String } class Functions(val functions: Fun*) extends Isabelle_System.Service lazy val functions: List[Fun] = Isabelle_System.make_services(classOf[Functions]).flatMap(_.functions) /** demo functions **/ object Echo extends Fun("echo") { def apply(arg: String): String = arg } object Sleep extends Fun("sleep") { def apply(seconds: String): String = { val t = seconds match { case Value.Double(s) => Time.seconds(s) case _ => error("Malformed argument: " + quote(seconds)) } val t0 = Time.now() t.sleep val t1 = Time.now() (t1 - t0).toString } } /** compiler **/ object Compiler { def context( error: String => Unit = Exn.error, jar_dirs: List[JFile] = Nil): Context = { def find_jars(dir: JFile): List[String] = File.find_files(dir, file => file.getName.endsWith(".jar")). map(File.absolute_name) val class_path = for { prop <- List("isabelle.scala.classpath", "java.class.path") path = System.getProperty(prop, "") if path != "\"\"" elem <- space_explode(JFile.pathSeparatorChar, path) } yield elem val settings = new GenericRunnerSettings(error) settings.classpath.value = (class_path ::: jar_dirs.flatMap(find_jars)).mkString(JFile.pathSeparator) new Context(settings) } def default_print_writer: PrintWriter = new NewLinePrintWriter(new ConsoleWriter, true) class Context private [Compiler](val settings: GenericRunnerSettings) { override def toString: String = settings.toString def interpreter( print_writer: PrintWriter = default_print_writer, class_loader: ClassLoader = null): IMain = { new IMain(settings, print_writer) { override def parentClassLoader: ClassLoader = if (class_loader == null) super.parentClassLoader else class_loader } } - def toplevel(source: String): List[String] = + def toplevel(interpret: Boolean, source: String): List[String] = { val out = new StringWriter val interp = interpreter(new PrintWriter(out)) - val rep = new interp.ReadEvalPrint - val ok = interp.withLabel("\u0001") { rep.compile(source) } + val ok = + interp.withLabel("\u0001") { + if (interpret) interp.interpret(source) == Results.Success + else (new interp.ReadEvalPrint).compile(source) + } out.close val Error = """(?s)^\S* error: (.*)$""".r val errors = space_explode('\u0001', Library.strip_ansi_color(out.toString)). collect({ case Error(msg) => "Scala error: " + Library.trim_line(msg) }) if (!ok && errors.isEmpty) List("Error") else errors } } } object Toplevel extends Fun("scala_toplevel") { - def apply(source: String): String = + def apply(arg: String): String = { + val (interpret, source) = + YXML.parse_body(arg) match { + case Nil => (false, "") + case List(XML.Text(source)) => (false, source) + case body => import XML.Decode._; pair(bool, string)(body) + } val errors = - try { Compiler.context().toplevel(source) } + try { Compiler.context().toplevel(interpret, source) } catch { case ERROR(msg) => List(msg) } locally { import XML.Encode._; YXML.string_of_body(list(string)(errors)) } } } /** invoke Scala functions from ML **/ /* invoke function */ object Tag extends Enumeration { val NULL, OK, ERROR, FAIL, INTERRUPT = Value } def function(name: String, arg: String): (Tag.Value, String) = functions.find(fun => fun.name == name) match { case Some(fun) => Exn.capture { fun(arg) } match { case Exn.Res(null) => (Tag.NULL, "") case Exn.Res(res) => (Tag.OK, res) case Exn.Exn(Exn.Interrupt()) => (Tag.INTERRUPT, "") case Exn.Exn(e) => (Tag.ERROR, Exn.message(e)) } case None => (Tag.FAIL, "Unknown Isabelle/Scala function: " + quote(name)) } /* protocol handler */ class Handler extends Session.Protocol_Handler { private var session: Session = null private var futures = Map.empty[String, Future[Unit]] override def init(session: Session): Unit = synchronized { this.session = session } override def exit(): Unit = synchronized { for ((id, future) <- futures) cancel(id, future) futures = Map.empty } private def result(id: String, tag: Scala.Tag.Value, res: String): Unit = synchronized { if (futures.isDefinedAt(id)) { session.protocol_command("Scala.result", id, tag.id.toString, res) futures -= id } } private def cancel(id: String, future: Future[Unit]) { future.cancel result(id, Scala.Tag.INTERRUPT, "") } private def invoke_scala(msg: Prover.Protocol_Output): Boolean = synchronized { msg.properties match { case Markup.Invoke_Scala(name, id) => futures += (id -> Future.fork { val (tag, res) = Scala.function(name, msg.text) result(id, tag, res) }) true case _ => false } } private def cancel_scala(msg: Prover.Protocol_Output): Boolean = synchronized { msg.properties match { case Markup.Cancel_Scala(id) => futures.get(id) match { case Some(future) => cancel(id, future) case None => } true case _ => false } } override val functions = List( Markup.Invoke_Scala.name -> invoke_scala, Markup.Cancel_Scala.name -> cancel_scala) } } class Scala_Functions extends Scala.Functions( Scala.Echo, Scala.Sleep, Scala.Toplevel, Bibtex.Check_Database) diff --git a/src/Pure/System/scala_compiler.ML b/src/Pure/System/scala_compiler.ML --- a/src/Pure/System/scala_compiler.ML +++ b/src/Pure/System/scala_compiler.ML @@ -1,96 +1,99 @@ (* Title: Pure/System/scala_compiler.ML Author: Makarius Scala compiler operations. *) signature SCALA_COMPILER = sig - val toplevel: string -> unit + val toplevel: bool -> string -> unit val static_check: string * Position.T -> unit end; structure Scala_Compiler: SCALA_COMPILER = struct (* check declaration *) -fun toplevel source = +fun toplevel interpret source = let val errors = - \<^scala>\scala_toplevel\ source + (interpret, source) + |> let open XML.Encode in pair bool string end + |> YXML.string_of_body + |> \<^scala>\scala_toplevel\ |> YXML.parse_body |> let open XML.Decode in list string end in if null errors then () else error (cat_lines errors) end; fun static_check (source, pos) = - toplevel ("package test\nclass __Dummy__ { __dummy__ => " ^ source ^ " }") + toplevel false ("package test\nclass __Dummy__ { __dummy__ => " ^ source ^ " }") handle ERROR msg => error (msg ^ Position.here pos); (* antiquotations *) local fun make_list bg en = space_implode "," #> enclose bg en; fun print_args [] = "" | print_args xs = make_list "(" ")" xs; fun print_types [] = "" | print_types Ts = make_list "[" "]" Ts; fun print_class (c, Ts) = c ^ print_types Ts; val types = Scan.optional (Parse.$$$ "[" |-- Parse.list1 Parse.name --| Parse.$$$ "]") []; val class = Scan.option (Parse.$$$ "(" |-- Parse.!!! (Parse.$$$ "in" |-- Parse.name -- types --| Parse.$$$ ")")); val arguments = (Parse.nat >> (fn n => replicate n "_") || Parse.list (Parse.underscore || Parse.name >> (fn T => "_ : " ^ T))) >> print_args; val args = Scan.optional (Parse.$$$ "(" |-- arguments --| Parse.$$$ ")") " _"; fun scala_name name = let val latex = Latex.string (Latex.output_ascii_breakable "." name) in Latex.enclose_block "\\isatt{" "}" [latex] end; in val _ = Theory.setup (Thy_Output.antiquotation_verbatim_embedded \<^binding>\scala\ (Scan.lift Args.embedded_position) (fn _ => fn (s, pos) => (static_check (s, pos); s)) #> Thy_Output.antiquotation_raw_embedded \<^binding>\scala_type\ (Scan.lift (Args.embedded_position -- (types >> print_types))) (fn _ => fn ((t, pos), type_args) => (static_check ("type _Test_" ^ type_args ^ " = " ^ t ^ type_args, pos); scala_name (t ^ type_args))) #> Thy_Output.antiquotation_raw_embedded \<^binding>\scala_object\ (Scan.lift Args.embedded_position) (fn _ => fn (x, pos) => (static_check ("val _test_ = " ^ x, pos); scala_name x)) #> Thy_Output.antiquotation_raw_embedded \<^binding>\scala_method\ (Scan.lift (class -- Args.embedded_position -- types -- args)) (fn _ => fn (((class_context, (method, pos)), method_types), method_args) => let val class_types = (case class_context of SOME (_, Ts) => Ts | NONE => []); val def = "def _test_" ^ print_types (merge (op =) (method_types, class_types)); val def_context = (case class_context of NONE => def ^ " = " | SOME c => def ^ "(_this_ : " ^ print_class c ^ ") = _this_."); val source = def_context ^ method ^ method_args; val _ = static_check (source, pos); val text = (case class_context of NONE => method | SOME c => print_class c ^ "." ^ method); in scala_name text end)); end; end;