diff --git a/mlsource/MLCompiler/CodeTree/CODETREE_OPTIMISER.sml b/mlsource/MLCompiler/CodeTree/CODETREE_OPTIMISER.sml index eb03d493..5b2c815f 100644 --- a/mlsource/MLCompiler/CodeTree/CODETREE_OPTIMISER.sml +++ b/mlsource/MLCompiler/CodeTree/CODETREE_OPTIMISER.sml @@ -1,1414 +1,1414 @@ (* Copyright (c) 2012,13,15,17 David C.J. Matthews This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License version 2.1 as published by the Free Software Foundation. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA *) functor CODETREE_OPTIMISER( structure BASECODETREE: BaseCodeTreeSig structure CODETREE_FUNCTIONS: CodetreeFunctionsSig structure REMOVE_REDUNDANT: sig type codetree type loadForm type codeUse val cleanProc : (codetree * codeUse list * (int -> loadForm) * int) -> codetree structure Sharing: sig type codetree = codetree and loadForm = loadForm and codeUse = codeUse end end structure SIMPLIFIER: sig type codetree and codeBinding and envSpecial val simplifier: { code: codetree, numLocals: int, maxInlineSize: int } -> (codetree * codeBinding list * envSpecial) * int * bool val specialToGeneral: codetree * codeBinding list * envSpecial -> codetree structure Sharing: sig type codetree = codetree and codeBinding = codeBinding and envSpecial = envSpecial end end structure DEBUG: DEBUGSIG structure PRETTY : PRETTYSIG structure BACKEND: sig type codetree type machineWord = Address.machineWord val codeGenerate: codetree * int * Universal.universal list -> (unit -> machineWord) * Universal.universal list structure Sharing : sig type codetree = codetree end end sharing BASECODETREE.Sharing = CODETREE_FUNCTIONS.Sharing = REMOVE_REDUNDANT.Sharing = SIMPLIFIER.Sharing = PRETTY.Sharing = BACKEND.Sharing ) : sig type codetree and envSpecial and codeBinding val codetreeOptimiser: codetree * Universal.universal list * int -> { numLocals: int, general: codetree, bindings: codeBinding list, special: envSpecial } structure Sharing: sig type codetree = codetree and envSpecial = envSpecial and codeBinding = codeBinding end end = struct open BASECODETREE open Address open CODETREE_FUNCTIONS open StretchArray infix 9 sub exception InternalError = Misc.InternalError (* Turn a list of fields to use into a filter for SetContainer. *) fun fieldsToFilter useList = let val maxDest = List.foldl Int.max ~1 useList val fields = BoolArray.array(maxDest+1, false) val _ = List.app(fn n => BoolArray.update(fields, n, true)) useList in BoolArray.vector fields end and filterToFields filter = BoolVector.foldri (fn (i, true, l) => i :: l | (_, _, l) => l) [] filter and setInFilter filter = BoolVector.foldl (fn (true, n) => n+1 | (false, n) => n) 0 filter (* Work-around for bug in bytevector equality. 
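   A hedged worked example of the filter helpers above (example values only):
       fieldsToFilter [0, 2]  builds the three-element vector true/false/true
       filterToFields filter  recovers [0, 2] from that vector
       setInFilter filter     counts the set entries, here 2
   boolVectorEq below therefore compares the recovered index lists rather than
   the byte vectors themselves.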
*) and boolVectorEq(a, b) = filterToFields a = filterToFields b fun buildFullTuple(filter, select) = let fun extArg(t, u) = if t = BoolVector.length filter then [] else if BoolVector.sub(filter, t) then select u :: extArg(t+1, u+1) else CodeZero :: extArg (t+1, u) in mkTuple(extArg(0, 0)) end (* When transforming code we only process one level and do not descend into sub-functions. *) local fun deExtract(Extract l) = l | deExtract _ = raise Misc.InternalError "deExtract" fun onlyFunction repEntry (Lambda{ body, isInline, name, closure, argTypes, resultType, localCount, recUse }) = SOME( Lambda { body = body, isInline = isInline, name = name, closure = map (deExtract o mapCodetree repEntry o Extract) closure, argTypes = argTypes, resultType = resultType, localCount = localCount, recUse = recUse } ) | onlyFunction repEntry code = repEntry code in fun mapFunctionCode repEntry = mapCodetree (onlyFunction repEntry) end local (* This transforms the body of a "small" recursive function replacing any reference to the arguments by the appropriate entry and the recursive calls themselves by either a Loop or a recursive call. *) fun mapCodeForFunctionRewriting(code, argMap, modVec, transformCall) = let fun repEntry(Extract(LoadArgument n)) = SOME(Extract(Vector.sub(argMap, n))) | repEntry(Eval { function = Extract LoadRecursive, argList, resultType }) = let (* Filter arguments to include only those that are changed and map any values we pass. They may include references to the parameters. *) fun mapArg((arg, argT)::rest, n) = if Vector.sub(modVec, n) then mapArg(rest, n+1) else (mapCode arg, argT) :: mapArg(rest, n+1) | mapArg([], _) = [] in SOME(transformCall(mapArg(argList, 0), resultType)) end | repEntry _ = NONE and mapCode code = mapFunctionCode repEntry code in mapCode code end in (* If we have a tail recursive function we can replace the tail calls by a loop. modVec indicates the arguments that have not changed. *) fun replaceTailRecursiveWithLoop(body, argTypes, modVec, nextAddress) = let (* We need to create local bindings for arguments that will change. Those that do not can be reused. *) local fun mapArgs((argT, use):: rest, n, decs, mapList) = if Vector.sub(modVec, n) then mapArgs (rest, n+1, decs, LoadArgument n :: mapList) else let val na = ! nextAddress before nextAddress := !nextAddress + 1 in mapArgs (rest, n+1, ({addr = na, value = mkLoadArgument n, use=use}, argT) :: decs, LoadLocal na :: mapList) end | mapArgs([], _, decs, mapList) = (List.rev decs, List.rev mapList) val (decs, mapList) = mapArgs(argTypes, 0, [], []) in val argMap = Vector.fromList mapList val loopArgs = decs end in BeginLoop { arguments = loopArgs, loop = mapCodeForFunctionRewriting(body, argMap, modVec, fn (l, _) => Loop l) } end (* If we have a small recursive function where some arguments are passed through unchanged we can transform it by extracting the stable arguments and only passing the changing arguments. The advantage is that this allows the stable arguments to be inserted inline which is important if they are functions. The canonical example is List.map. *) fun liftRecursiveFunction(body, argTypes, modVec, closureSize, name, resultType, localCount) = let local fun getArgs((argType, use)::rest, nArg, clCount, argCount, stable, change, mapList) = let (* This is the argument from the outer function. It is either added to the closure or passed to the inner function. 
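   As a sketch, for the canonical List.map case mentioned above (this
   definition is illustrative, not part of the source):
       fun map f []     = []
         | map f (h::t) = f h :: map f t
   Here f never changes across the recursive calls, so modVec marks it stable:
   it is captured in the closure of the lifted inner function and only the
   list argument is passed on each call, allowing an inline f to be expanded.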
*) val argN = LoadArgument nArg in if Vector.sub(modVec, nArg) then getArgs(rest, nArg+1, clCount+1, argCount, argN :: stable, change, LoadClosure clCount :: mapList) else getArgs(rest, nArg+1, clCount, argCount+1, stable, (Extract argN, argType, use) :: change, LoadArgument argCount :: mapList) end | getArgs([], _, _, _, stable, change, mapList) = (List.rev stable, List.rev change, List.rev mapList) in (* The stable args go into the closure. The changeable args are passed in. *) val (stableArgs, changeArgsAndTypes, mapList) = getArgs(argTypes, 0, closureSize, 0, [], [], []) val argMap = Vector.fromList mapList end val subFunction = Lambda { body = mapCodeForFunctionRewriting(body, argMap, modVec, fn (l, t) => Eval { function = Extract LoadRecursive, argList = l, resultType = t }), isInline = NonInline, (* Don't inline this function. *) name = name ^ "()", closure = List.tabulate(closureSize, fn n => LoadClosure n) @ stableArgs, argTypes = List.map (fn (_, t, u) => (t, u)) changeArgsAndTypes, resultType = resultType, localCount = localCount, recUse = [UseGeneral] } in Eval { function = subFunction, argList = map (fn (c, t, _) => (c, t)) changeArgsAndTypes, resultType = resultType } end end (* If the function arguments are used in a way that could be optimised the data structure represents it. *) datatype functionArgPattern = ArgPattTuple of { filter: BoolVector.vector, allConst: bool, fromFields: bool } (* ArgPattCurry is a list, one per level of application, of a list, one per argument of the pattern for that argument. *) | ArgPattCurry of functionArgPattern list list * functionArgPattern | ArgPattSimple (* Returns ArgPattCurry even if it is just a single application. *) local (* Control how we check for side-effects. *) datatype curryControl = CurryNoCheck | CurryCheck | CurryReorderable local open Address (* Return the width of a tuple. Returns 1 for non-tuples including datatypes where different variants could have different widths. Also returns a flag indicating if the value came from a constant. Constants are already tupled so there's no advantage in untupling them unless there are other non-constant arguments as well. *) fun findTuple(Tuple{fields, isVariant=false}) = (List.length fields, false) | findTuple(Constnt(w, _)) = if isShort w orelse flags (toAddress w) <> F_words then (1, false) else (Word.toInt(length (toAddress w)), true) | findTuple(Extract _) = (1, false) (* TODO: record this for variables *) | findTuple(Cond(_, t, e)) = let val (tl, tc) = findTuple t and (el, ec) = findTuple e in if tl = el then (tl, tc andalso ec) else (1, false) end | findTuple(Newenv(_, e)) = findTuple e | findTuple _ = (1, false) in fun mapArg c = let val (n, f) = findTuple c in if n <= 1 then ArgPattSimple else ArgPattTuple{filter=BoolVector.tabulate(n, fn _ => true), allConst=f, fromFields=false} end end fun useToPattern _ [] = ArgPattSimple | useToPattern checkCurry (hd::tl) = let (* Construct a possible pattern from the head. *) val p1 = case hd of UseApply(resl, arguments) => let (* If the result is also curried extend the list. 
*) val subCheck = case checkCurry of CurryCheck => CurryReorderable | c => c val (resultPatts, resultResult) = case useToPattern subCheck resl of ArgPattCurry l => l | tupleOrSimple => ([], tupleOrSimple) val thisArg = map mapArg arguments in (* If we have an argument that is a curried function we can safely apply it to the first argument even if that has a side-effect but we can't uncurry further than that because the behaviour could rely on a side-effect of the first application. *) if checkCurry = CurryReorderable andalso List.exists(not o reorderable) arguments then ArgPattSimple else ArgPattCurry(thisArg :: resultPatts, resultResult) end | UseField (n, _) => ArgPattTuple{filter=BoolVector.tabulate(n+1, fn m => m=n), allConst=false, fromFields=true} | _ => ArgPattSimple fun mergePattern(ArgPattCurry(l1, r1), ArgPattCurry(l2, r2)) = let (* Each argument list should be the same length. The length here is the number of arguments provided to this application. *) fun mergeArgLists(al1, al2) = ListPair.mapEq mergePattern (al1, al2) (* The currying lists could be different lengths because some applications could only partially apply it. It is essential not to assume more currying than the minimum so we stop with the shorter. *) val prefix = ListPair.map mergeArgLists (l1, l2) in if null prefix then ArgPattSimple else ArgPattCurry(prefix, mergePattern(r1, r2)) end | mergePattern(ArgPattTuple{filter=n1, allConst=c1, fromFields=f1}, ArgPattTuple{filter=n2, allConst=c2, fromFields=f2}) = (* If the tuples are different sizes we can't use a tuple. Unlike currying it would be safe to assume tupling where there isn't (unless the function is actually polymorphic). *) if boolVectorEq(n1, n2) then ArgPattTuple{filter=n1, allConst=c1 andalso c2, fromFields = f1 andalso f2} else if f1 andalso f2 then let open BoolVector val l1 = length n1 and l2 = length n2 fun safesub(n, v) = if n < length v then v sub n else false val union = tabulate(Int.max(l1, l2), fn n => safesub(n, n1) orelse safesub(n, n2)) in ArgPattTuple{filter=union, allConst=c1 andalso c2, fromFields = f1 andalso f2} end else ArgPattSimple | mergePattern _ = ArgPattSimple in case tl of [] => p1 | tl => mergePattern(p1, useToPattern checkCurry tl) end (* If the result is just a function where all the arguments are simple it's not actually curried. *) fun usageToPattern checkCurry use = case useToPattern checkCurry use of (* a as ArgPattCurry [s] => if List.all(fn ArgPattSimple => true | _ => false) s then ArgPattSimple else a |*) patt => patt in (* Decurrying involves reordering (f exp1) exp2 into code where any effects of evaluating exp2 are done before the application. That's only safe if either (f exp1) or exp2 have no side-effects and do not depend on references. In the case of the function body we can check that the body does not depend on any references (typically it's a lambda) but for function arguments we have to check how it is applied. *) val usageForFunctionBody = usageToPattern CurryNoCheck and usageForFunctionArg = usageToPattern CurryCheck (* To decide whether we want to detuple the argument we look to see if the function is ever applied to a tuple. This is rather different to currying where we only decurry if every application is to multiple arguments. This information is then merged with information about the arguments within the function. 
*) fun existTupling (use: codeUse list): functionArgPattern list = let val argListLists = List.foldl (fn (UseApply(_, args), l) => map mapArg args :: l | (_, l) => l) [] use fun orMerge [] = raise Empty | orMerge [hd] = hd | orMerge (hd1 :: hd2 :: tl) = let fun merge(a as ArgPattTuple _, _) = a | merge(_, b) = b in orMerge(ListPair.mapEq merge (hd1, hd2) :: tl) end in orMerge argListLists end (* If the result of a function contains a tuple but it is not detupled on every path, see if it is detupled on at least one. *) fun existDetupling(UseApply(resl, _) :: rest) = List.exists(fn UseField _ => true | _ => false) resl orelse existDetupling rest | existDetupling(_ :: rest) = existDetupling rest | existDetupling [] = false end (* Return a tuple if any of the branches returns a tuple. The idea is that if the body actually constructs a tuple on the heap on at least one branch it is probably worth attempting to detuple the result. *) fun bodyReturnsTuple (Tuple{fields, isVariant=false}) = ArgPattTuple{ filter=BoolVector.tabulate(List.length fields, fn _ => true), allConst=false, fromFields=false } | bodyReturnsTuple(Cond(_, t, e)) = ( case bodyReturnsTuple t of a as ArgPattTuple _ => a | _ => bodyReturnsTuple e ) | bodyReturnsTuple(Newenv(_, exp)) = bodyReturnsTuple exp | bodyReturnsTuple _ = ArgPattSimple (* If the usage indicates that the body of the function should be transformed these do the transformation. It is possible that each of these cases could apply and it would be possible to merge them all. For the moment keep them separate. If another of the cases applies this will be re-entered on a subsequent pass. *) fun detupleResult({ argTypes, name, resultType, closure, isInline, localCount, body, ...}: lambdaForm , filter, makeAddress) = (* The function returns a tuple or at least the uses of the function take apart a tuple. Transform it to take a container as an argument and put the result in there. *) let local fun mapArg f n ((t, _) :: tl) = (Extract(f n), t) :: mapArg f (n+1) tl | mapArg _ _ [] = [] in fun mapArgs f l = mapArg f 0 l end val mainAddress = makeAddress() and shimAddress = makeAddress() (* The main function performs the previous computation but puts the result into the container. We need to replace any recursive references with calls to the shim.*) local val recEntry = LoadClosure(List.length closure) fun doMap(Extract LoadRecursive) = SOME(Extract recEntry) | doMap _ = NONE in val transBody = mapFunctionCode doMap body end local val containerArg = Extract(LoadArgument(List.length argTypes)) val newBody = SetContainer{container = containerArg, tuple = transBody, filter=filter } val mainLambda: lambdaForm = { body = newBody, name = name, resultType=GeneralType, argTypes=argTypes @ [(GeneralType, [])], closure=closure @ [LoadLocal shimAddress], localCount=localCount + 1, isInline=isInline, recUse = [UseGeneral] } in val mainFunction = (mainAddress, mainLambda) end (* The shim function creates a container, passes it to the main function and then builds a tuple from the container. 
*) val shimBody = mkEnv( [Container{addr = 0, use = [], size = setInFilter filter, setter= Eval { function = Extract(LoadClosure 0), argList = mapArgs LoadArgument argTypes @ [(Extract(LoadLocal 0), GeneralType)], resultType = GeneralType } } ], buildFullTuple(filter, fn n => mkIndContainer(n, mkLoadLocal 0)) ) val shimLambda = { body = shimBody, name = name, argTypes = argTypes, closure = [LoadLocal mainAddress], resultType = resultType, isInline = Inline, localCount = 1, recUse = [UseGeneral] } val shimFunction = (shimAddress, shimLambda) in (shimLambda, [mainFunction, shimFunction]) end fun transformFunctionArgs({ argTypes, name, resultType, closure, isInline, localCount, body, ...} , usage, makeAddress) = (* Not curried - just a single argument. *) let (* We need to construct an inline "shim" function that has the same calling pattern as the original. This simply calls the transformed main function. We need to construct the arguments to call the transformed main function. That needs, for example, to unpack tuples and repack argument functions. We need to produce an argument map to transform the main function. This needs, for example, to pack the arguments into tuples. Then when the code is run through the simplifier the tuples will be optimised away. *) val localCounter = ref localCount fun mapPattern(ArgPattTuple{filter, allConst=false, ...} :: patts, n, m) = let val fieldList = filterToFields filter val (decs, args, mapList) = mapPattern(patts, n+1, m + setInFilter filter) val newAddr = ! localCounter before localCounter := ! localCounter + 1 val tuple = buildFullTuple(filter, fn u => mkLoadArgument(m+u)) val thisDec = Declar { addr = newAddr, use = [], value = tuple } (* Arguments for the call *) val thisArg = List.map(fn p => mkInd(p, mkLoadArgument n)) fieldList in (thisDec :: decs, thisArg @ args, LoadLocal newAddr :: mapList) end | mapPattern(ArgPattCurry(currying as [_], ArgPattTuple{allConst=false, filter, ...}) :: patts, n, m) = (* It's a function that returns a tuple. The function must not be curried because otherwise it returns a function not a tuple. *) let val (thisDec, thisArg, thisMap) = transformFunctionArgument(currying, [LoadArgument m], [LoadArgument n], SOME filter) val (decs, args, mapList) = mapPattern(patts, n+1, m+1) in (thisDec :: decs, thisArg :: args, thisMap :: mapList) end | mapPattern(ArgPattCurry(currying as firstArgSet :: _, _) :: patts, n, m) = (* Transform it if it's curried or if there is a tuple in the first arg. *) if (*List.length currying >= 2 orelse *) (* This transformation is unsafe. *) List.exists(fn ArgPattTuple{allConst=false, ...} => true | _ => false) firstArgSet then let val (thisDec, thisArg, thisMap) = transformFunctionArgument(currying, [LoadArgument m], [LoadArgument n], NONE) val (decs, args, mapList) = mapPattern(patts, n+1, m+1) in (thisDec :: decs, thisArg :: args, thisMap :: mapList) end else let val (decs, args, mapList) = mapPattern(patts, n+1, m+1) in (decs, Extract(LoadArgument n) :: args, LoadArgument m :: mapList) end | mapPattern(_ :: patts, n, m) = let val (decs, args, mapList) = mapPattern(patts, n+1, m+1) in (decs, Extract(LoadArgument n) :: args, LoadArgument m :: mapList) end | mapPattern([], _, _) = ([], [], []) and transformFunctionArgument(argumentArgs, loadPack, loadThisArg, filterOpt) = let (* Disable the transformation of curried arguments for the moment. This is unsafe. See Test146. The problem is that this transformation is only safe if the function is applied immediately to all the arguments. 
However the usage information is propagated so that if the result of the first application is bound to a variable and then that variable is applied it still appears as curried. *) val argumentArgs = [hd argumentArgs] (* We have a function that takes a series of curried arguments. Change that so that the function takes a list of arguments. *) val newAddr = ! localCounter before localCounter := ! localCounter + 1 (* In the main function we are expecting to call the argument in a curried fashion. We need to construct a function that packages up the arguments and, when all of them have been provided, calls the actual argument. *) local fun curryPack([], fnclosure) = let (* We're ready to call the function. We now need to unpack any tupled arguments. *) fun mapArgs(c :: ctl, args) = let fun mapArg([], args) = mapArgs(ctl, args) | mapArg(ArgPattTuple{filter, allConst=false, ...} :: patts, arg :: argctl) = let val fields = filterToFields filter in List.map(fn p => (mkInd(p, Extract arg), GeneralType)) fields @ mapArg(patts, argctl) end | mapArg(_ :: patts, arg :: argctl) = (Extract arg, GeneralType) :: mapArg(patts, argctl) | mapArg(_, []) = raise InternalError "mapArgs: mismatch" in mapArg(c, args) end | mapArgs _ = [] val argList = mapArgs(argumentArgs, tl fnclosure) in case filterOpt of NONE => Eval { function = Extract(hd fnclosure), resultType = GeneralType, argList = argList } | SOME filter => (* We need a container here for the result. *) mkEnv( [ Container{addr=0, size=setInFilter filter, use=[UseGeneral], setter= Eval { function = Extract(hd fnclosure), resultType = GeneralType, argList = argList @ [(mkLoadLocal 0, GeneralType)] } } ], buildFullTuple(filter, fn n => mkIndContainer(n, mkLoadLocal 0)) ) end | curryPack(hd :: tl, fnclosure) = let val nArgs = List.length hd (* If this is the last then we need to include the container if required. *) val needContainer = case (tl, filterOpt) of ([], SOME _) => true | _ => false in Lambda { closure = fnclosure, isInline = Inline, name = name ^ "-P", resultType = GeneralType, argTypes = List.tabulate(nArgs, fn _ => (GeneralType, [UseGeneral])), localCount = if needContainer then 1 else 0, recUse = [], body = curryPack(tl, (* The closure for the next level is the current closure together with all the arguments at this level. *) List.tabulate(List.length fnclosure, fn n => LoadClosure n) @ List.tabulate(nArgs, LoadArgument)) } end in val packFn = curryPack(argumentArgs, loadPack) end val thisDec = Declar { addr = newAddr, use = [], value = packFn } fun argCount(ArgPattTuple{filter, allConst=false, ...}, m) = setInFilter filter + m | argCount(_, m) = m+1 local (* In the shim function, i.e. the inline function outside, we have a lambda that will be called when the main function wants to call its argument function. This is provided with all the arguments and so it has to call the actual argument, which is expected to be curried, an argument at a time. 
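   For example (hypothetical shapes): with one level of two arguments the
   generated body is essentially Eval{function=argFn, argList=[a1, a2]}; with
   two levels of one argument each it is
   Eval{function=Eval{function=argFn, argList=[a1]}, argList=[a2]}.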
*) fun curryApply(hd :: tl, n, c) = let fun makeArgs(_, []) = [] | makeArgs(q, ArgPattTuple{filter, allConst=false, ...} :: args) = (buildFullTuple(filter, fn r => mkLoadArgument(r+q)), GeneralType) :: makeArgs(q + setInFilter filter, args) | makeArgs(q, _ :: args) = (mkLoadArgument q, GeneralType) :: makeArgs(q+1, args) val args = makeArgs(n, hd) in curryApply(tl, n + List.foldl argCount 0 hd, Eval{function=c, resultType = GeneralType, argList=args}) end | curryApply([], _, c) = c in val thisBody = curryApply (argumentArgs, 0, mkLoadClosure 0) end local (* We have one argument for each argument at each level of currying, or where we've expanded a tuple, one argument for each field. If the function is returning a tuple we have an extra argument for the container. *) val totalArgCount = List.foldl(fn (c, n) => n + List.foldl argCount 0 c) 0 argumentArgs + (case filterOpt of SOME _ => 1 | _ => 0) val functionBody = case filterOpt of NONE => thisBody | SOME filter => mkSetContainer(mkLoadArgument(totalArgCount-1), thisBody, filter) in val thisArg = Lambda { closure = loadThisArg, isInline = Inline, name = name ^ "-E", argTypes = List.tabulate(totalArgCount, fn _ => (GeneralType, [UseGeneral])), resultType = GeneralType, localCount = 0, recUse = [UseGeneral], body = functionBody } end in (thisDec, thisArg, LoadLocal newAddr) end val (extraBindings, transArgCode, argMapList) = mapPattern(usage, 0, 0) local (* Transform the body by replacing the arguments with the new arguments. *) val argMap = Vector.fromList argMapList (* If we have a recursive reference we have to replace it with a reference to the shim. *) val recEntry = LoadClosure(List.length closure) fun doMap(Extract(LoadArgument n)) = SOME(Extract(Vector.sub(argMap, n))) | doMap(Extract LoadRecursive) = SOME(Extract recEntry) | doMap _ = NONE in val transBody = mapFunctionCode doMap body end local (* The argument types for the main function have the tuples expanded, Functions are not affected. *) fun expand(ArgPattTuple{filter, allConst=false, ...}, _, r) = List.tabulate(setInFilter filter, fn _ => (GeneralType, [])) @ r | expand(_, a, r) = a :: r in val transArgTypes = ListPair.foldrEq expand [] (usage, argTypes) end (* Add the type information to the argument code. *) val transArgs = ListPair.mapEq(fn (c, (t, _)) => (c, t)) (transArgCode, transArgTypes) val mainAddress = makeAddress() and shimAddress = makeAddress() val transLambda = { body = mkEnv(extraBindings, transBody), name = name, argTypes = transArgTypes, closure = closure @ [LoadLocal shimAddress], resultType = resultType, isInline = isInline, localCount = ! localCounter, recUse = [UseGeneral] } (* Return the pair of functions. *) val mainFunction = (mainAddress, transLambda) val shimBody = Eval { function = Extract(LoadClosure 0), argList = transArgs, resultType = resultType } val shimLambda = { body = shimBody, name = name, argTypes = argTypes, closure = [LoadLocal mainAddress], resultType = resultType, isInline = Inline, localCount = 0, recUse = [UseGeneral] } val shimFunction = (shimAddress, shimLambda) (* TODO: We have two copies of the shim function here. *) in (shimLambda, [mainFunction, shimFunction]) end fun decurryFunction( { argTypes, name, resultType, closure, isInline, localCount, body as Lambda { argTypes=subArgTypes, resultType=subResultType, ... } , ...}, makeAddress) = (* Curried - just unwind one level this time. This case is normally dealt with by the front-end at least for fun bindings. 
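   A sketch of the intent (source level, names invented): fn a => fn b => body
   is rewritten as a main function main(a, b) = body together with an inline
   shim fn a => fn b => main(a, b); once the shim is inlined, a full
   application f x y becomes a single call of main.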
*) let local fun mapArg f n ((t, _) :: tl) = (Extract(f n), t) :: mapArg f (n+1) tl | mapArg _ _ [] = [] in fun mapArgs f l = mapArg f 0 l end val mainAddress = makeAddress() and shimAddress = makeAddress() (* The main function calls the original body as a function. The body is a lambda which will contain references to the outer arguments but because we're just adding arguments these will be as before. *) (* We have to transform any recursive references to point to the shim. *) local val recEntry = LoadClosure(List.length closure) fun doMap(Extract LoadRecursive) = SOME(Extract recEntry) | doMap _ = NONE in val transBody = mapFunctionCode doMap body end val arg1Count = List.length argTypes val mainLambda = { body = Eval{ function = transBody, resultType = subResultType, argList = mapArgs (fn n => LoadArgument(n+arg1Count)) subArgTypes }, name = name, resultType = subResultType, closure = closure @ [LoadLocal shimAddress], isInline = isInline, localCount = localCount, argTypes = argTypes @ subArgTypes, recUse = [UseGeneral] } val mainFunction = (mainAddress, mainLambda) val shimInnerLambda = Lambda { (* The inner shim closure contains the main function and the outer arguments. *) closure = LoadClosure 0 :: List.tabulate(arg1Count, LoadArgument), body = Eval { function = Extract(LoadClosure 0), resultType = resultType, (* Calls main function with both sets of args. *) argList = mapArgs (fn n => LoadClosure(n+1)) argTypes @ mapArgs LoadArgument subArgTypes }, name = name ^ "-", resultType = subResultType, localCount = 0, isInline = Inline, argTypes = subArgTypes, recUse = [UseGeneral] } val shimOuterLambda = { body = shimInnerLambda, name = name, argTypes = argTypes, closure = [LoadLocal mainAddress], resultType = resultType, isInline = Inline, localCount = 0, recUse = [UseGeneral] } val shimFunction = (shimAddress, shimOuterLambda) in (shimOuterLambda: lambdaForm, [mainFunction, shimFunction]) end | decurryFunction _ = raise InternalError "decurryFunction" (* Process a Lambda slightly differently in different contexts. *) datatype lambdaContext = LCNormal | LCRecursive | LCImmediateCall (* Transforming a lambda may result in producing auxiliary functions that are in general mutually recursive. *) fun mapLambdaResult([], lambda) = lambda | mapLambdaResult(bindings, lambda) = mkEnv([RecDecs(map(fn(addr, lam) => {addr=addr, use=[], lambda=lam}) bindings)], lambda) fun optimise (context, use) (Lambda lambda) = SOME(mapLambdaResult(optLambda(context, use, lambda, LCNormal))) | optimise (context, use) (Newenv(envDecs, envExp)) = let fun mapExp mapUse = mapCodetree (optimise(context, mapUse)) fun mapbinding(Declar{value, addr, use}) = Declar{value=mapExp use value, addr=addr, use=use} | mapbinding(RecDecs l) = let fun mapRecDec({addr, lambda, use}, rest) = case optLambda(context, use, lambda, LCRecursive) of (bindings, Lambda lambdaRes) => (* Turn any bindings into extra mutually-recursive functions. *) {addr=addr, use = use, lambda = lambdaRes } :: map (fn (addr, res) => {addr=addr, use=use, lambda=res }) bindings @ rest | _ => raise InternalError "mapbinding: not lambda" in RecDecs(foldl mapRecDec [] l) end | mapbinding(NullBinding exp) = NullBinding(mapExp [UseGeneral] exp) | mapbinding(Container{addr, use, size, setter}) = Container{addr=addr, use=use, size=size, setter = mapExp [UseGeneral] setter} in SOME(Newenv(map mapbinding envDecs, mapExp use envExp)) end (* Immediate call to a function. We may be able to expand this inline unless it is recursive. 
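   For example, (fn x => x + 1) y reaches this case; optLambda is called with
   LCImmediateCall, which raises the inlining size threshold, and a later
   simplifier pass expands the call in place.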
*) | optimise (context, use) (Eval {function = Lambda lambda, argList, resultType}) = let val args = map (fn (c, t) => (optGeneral context c, t)) argList val argTuples = map #1 args val (bindings, newLambda) = optLambda(context, [UseApply(use, argTuples)], lambda, LCImmediateCall) val call = Eval { function=newLambda, argList=args, resultType = resultType } in SOME(mapLambdaResult(bindings, call)) end | optimise (context as { reprocess, ...}, use) (Eval {function = Cond(i, t, e), argList, resultType}) = let (* Transform "(if i then t else e) x" into "if i then t x else e x". This allows for other optimisations and inline expansion. *) (* We duplicate the function arguments which could cause the size of the code to blow-up if they involve complicated expressions. *) fun pushFunction l = mapCodetree (optimise(context, use)) (Eval{function=l, argList=argList, resultType=resultType}) in reprocess := true; SOME(Cond(i, pushFunction t, pushFunction e)) end | optimise (context, use) (Eval {function, argList, resultType}) = (* If nothing else we need to ensure that "use" is correctly set on the function and arguments and we don't simply pass the original. *) let val args = map (fn (c, t) => (optGeneral context c, t)) argList val argTuples = map #1 args in SOME( Eval{ function= mapCodetree (optimise (context, [UseApply(use, argTuples)])) function, argList=args, resultType = resultType }) end | optimise (context, use) (Indirect{base, offset, indKind = IndTuple}) = SOME(Indirect{base = mapCodetree (optimise(context, [UseField(offset, use)])) base, offset = offset, indKind = IndTuple}) | optimise (context, use) (code as Cond _) = (* If the result of the if-then-else is always taken apart as fields then we are better off taking it apart further down and putting the fields into a container on the stack. *) if List.all(fn UseField _ => true | _ => false) use then SOME(optFields(code, context, use)) else NONE | optimise (context, use) (code as BeginLoop _) = (* If the result of the loop is taken apart we should push this down as well. *) if List.all(fn UseField _ => true | _ => false) use then SOME(optFields(code, context, use)) else NONE | optimise _ _ = NONE and optGeneral context exp = mapCodetree (optimise(context, [UseGeneral])) exp and optLambda( { maxInlineSize, reprocess, makeAddr, ... }, contextUse, { body, name, argTypes, resultType, closure, localCount, isInline, recUse, ...}, lambdaContext) : (int * lambdaForm) list * codetree = (* Optimisations on lambdas. 1. A lambda that simply calls another function with all its own arguments can be replaced by a reference to the function provided the "function" is a side-effect-free expression. 2. Don't attempt to optimise inline functions that are exported. 3. Transform lambdas that take tuples as arguments or are curried or where an argument is a function with tupled or curried arguments into a pair of an inline function with the original argument set and a new "main" function with register/stack arguments. *) let (* The overall use of the function is the context plus the recursive use. *) val use = contextUse @ recUse (* Check if it's a call to another function with all the original arguments. This is really wanted when we are passing this lambda as an argument to another function and really only when we have produced a shim function that has been inline expanded. Otherwise this will be a "small" function and will be inline expanded when it's used. 
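   For instance (hypothetical): a shim of the form fn (a, b) => g (a, b),
   where g is an entry in the closure, matches argSequence below and is
   replaced by g itself, leaving no lambda behind.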
*) val replaceBody = case (body, lambdaContext = LCRecursive) of (Eval { function, argList, resultType=callresult }, false) => let fun argSequence((Extract(LoadArgument a), _) :: rest, b) = a = b andalso argSequence(rest, b+1) | argSequence([], _) = true | argSequence _ = false val argumentsMatch = argSequence(argList, 0) andalso ListPair.allEq(fn((_, a), (b, _)) => a = b) (argList, argTypes) andalso callresult = resultType in if not argumentsMatch then NONE else case function of (* This could be any function which has neither side-effects nor depends on a reference nor depends on another argument but if it has local variables they would have to be renumbered into the surrounding scope. In practice we're really only interested in simple cases that arise as a result of using a "shim" function created in the code below. *) c as Constnt _ => SOME c | Extract(LoadClosure addr) => SOME(Extract(List.nth(closure, addr))) | _ => NONE end | _ => NONE in case replaceBody of SOME c => ([], c) | NONE => if isInline = Inline andalso List.exists (fn UseExport => true | _ => false) use then let (* If it's inline any application of this will be optimised after inline expansion. We still apply any optimisations to the body at this stage because we will compile and code-generate a version for use if we want a "general" value. *) val addressAllocator = ref localCount val optContext = { makeAddr = fn () => (! addressAllocator) before addressAllocator := ! addressAllocator + 1, reprocess = reprocess, maxInlineSize = maxInlineSize } val optBody = mapCodetree (optimise(optContext, [UseGeneral])) body val lambdaRes = { body = optBody, isInline = isInline, name = name, closure = closure, argTypes = argTypes, resultType = resultType, recUse = recUse, localCount = !addressAllocator (* After optimising body. *) } in ([], Lambda lambdaRes) end else let (* Allocate any new addresses after the existing ones. *) val addressAllocator = ref localCount val optContext = { makeAddr = fn () => (! addressAllocator) before addressAllocator := ! addressAllocator + 1, reprocess = reprocess, maxInlineSize = maxInlineSize } val optBody = mapCodetree (optimise(optContext, [UseGeneral])) body (* See if this should be expanded inline. If we are calling the lambda immediately we try to expand it unless maxInlineSize is zero. We may not be able to expand it if it is recursive. (It may have been inside an inline function). *) val (inlineType, updatedBody, localCount) = case evaluateInlining(optBody, List.length argTypes, if maxInlineSize <> 0 andalso lambdaContext = LCImmediateCall then 1000 else FixedInt.toInt maxInlineSize) of NonRecursive => (Inline, optBody, ! addressAllocator) | TailRecursive bv => (Inline, replaceTailRecursiveWithLoop(optBody, argTypes, bv, addressAllocator), ! addressAllocator) | NonTailRecursive bv => if Vector.exists (fn n => n) bv then (Inline, liftRecursiveFunction( optBody, argTypes, bv, List.length closure, name, resultType, !addressAllocator), 0) else (NonInline, optBody, ! addressAllocator) (* All arguments have been modified *) | TooBig => (NonInline, optBody, ! addressAllocator) val lambda: lambdaForm = { body = updatedBody, name = name, argTypes = argTypes, closure = closure, resultType = resultType, isInline = inlineType, localCount = localCount, recUse = recUse } (* See if it should be transformed. We only do this if the function is not going to be inlined. If it is then there's no point because the transformation is going to be done as part of the inlining process. 
Even if it's marked for inlining we may not actually call the function and instead pass it as an argument or return it as result but in that case transformation doesn't achieve anything because we are going to pass the untransformed "shim" function anyway. *) val (newLambda, bindings) = if isInline = NonInline then let val functionPattern = case usageForFunctionBody use of ArgPattCurry(arg1 :: arg2 :: moreArgs, res) => (* The function is always called with at least two curried arguments. We can decurry the function if the body is applicative - typically if it's a lambda - but not if applying the body would have a side-effect. We only do it one level at this stage. If it's curried more than that we'll come here again. *) (* In order to get the types we restrict this to the case of a body that is a lambda. The result is a function and therefore ArgPattSimple unless we are using up all the args. *) if (*reorderable body*) case updatedBody of Lambda _ => true | _ => false then ArgPattCurry([arg1, arg2], if null moreArgs then res else ArgPattSimple) else ArgPattCurry([arg1], ArgPattSimple) | usage => usage val argPatterns = map (usageForFunctionArg o #2) argTypes (* fullArgPattern is a list, one per level of currying, of a list, one per argument of the patterns. resultPattern is used to detect whether the result is a tuple that is taken apart. *) val (fullArgPattern, resultPattern) = case functionPattern of ArgPattCurry(_ :: rest, resPattern) => let (* The function is always applied at least to the first set of arguments. (It's never just passed). Merge the applications of the function with the use of the arguments. Return the usage within the function unless the function takes apart a tuple but no application passes in a tuple. *) fun merge(ArgPattTuple _, argUse as ArgPattTuple _) = argUse | merge(_, ArgPattTuple _) = ArgPattSimple | merge(_, argUse) = argUse val mergedArgs = (ListPair.mapEq merge (existTupling use, argPatterns)) :: rest (* *) val mergedResult = case (bodyReturnsTuple updatedBody, resPattern) of (bodyTuple as ArgPattTuple _, ArgPattSimple) => if existDetupling use then bodyTuple else ArgPattSimple | _ => resPattern in (mergedArgs, mergedResult) end | _ => (* Not called: either exported or passed as a value. *) (* This previously tried to see whether the body returned a tuple if the function was exported. This caused an infinite loop (see Tests/Succeed/Test164.ML) and anyway doesn't seem to optimise the cases we want. *) ([], ArgPattSimple) in case (fullArgPattern, resultPattern) of (_ :: _ :: _, _) => (* Curried *) ( reprocess := true; decurryFunction(lambda, makeAddr)) | (_, ArgPattTuple {filter, ...}) => (* Result is a tuple *) ( reprocess := true; detupleResult(lambda, filter, makeAddr)) | (first :: _, _) => let fun checkArg (ArgPattTuple{allConst=false, ...}) = true (* Function has at least one tupled arg. *) | checkArg (ArgPattCurry([_], ArgPattTuple{allConst=false, ...})) = true (* Function has an arg that is a function that returns a tuple. It must not be curried otherwise it returns a function not a tuple. *) (* This transformation is unsafe. See comment in transformFunctionArgument above. *) (*| checkArg (ArgPattCurry(_ :: _ :: _, _)) = true *) (* Function has an arg that is a curried function. *) | checkArg (ArgPattCurry(firstArgSet :: _, _)) = (* Function has an arg that is a function that takes a tuple in its first argument set. 
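   e.g. (hypothetically) a parameter g that is always applied as g (x, y):
   its pattern is ArgPattCurry with an ArgPattTuple in the first argument
   set, so checkArg succeeds.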
*) List.exists(fn ArgPattTuple{allConst=false, ...} => true | _ => false) firstArgSet | checkArg _ = false in (* It isn't curried - look at the arguments. *) if List.exists checkArg first then ( reprocess := true; transformFunctionArgs(lambda, first, makeAddr) ) else (lambda, []) end | _ => (lambda, []) end else (lambda, []) in (* If this is to be inlined but was not before we may need to reprocess. We don't reprocess if this is only exported. If it's only exported we're not going to expand it within this code and we can end up with repeated processing. *) if #isInline newLambda = Inline andalso isInline = NonInline andalso (case use of [UseExport] => false | _ => true) then reprocess := true else (); (bindings, Lambda newLambda) end end and optFields (code, context as { reprocess, makeAddr, ...}, use) = let (* We have an if-then-else or a loop whose result is only ever taken apart. We push this down. *) (* Find the fields that are used. Not all may be. *) local val maxField = List.foldl(fn (UseField(f, _), m) => Int.max(f, m) | (_, m) => m) 0 use val fieldUse = BoolArray.array(maxField+1, false) val _ = List.app(fn UseField(f, _) => BoolArray.update(fieldUse, f, true) | _ => ()) use in val maxField = maxField val useList = BoolArray.foldri (fn (i, true, l) => i :: l | (_, _, l) => l) [] fieldUse end fun pushContainer(Cond(ifpt, thenpt, elsept), leafFn) = Cond(ifpt, pushContainer(thenpt, leafFn), pushContainer(elsept, leafFn)) | pushContainer(Newenv(decs, exp), leafFn) = Newenv(decs, pushContainer(exp, leafFn)) | pushContainer(BeginLoop{loop, arguments}, leafFn) = (* If we push it through a BeginLoop we MUST then push it through anything that could contain the Loop i.e. Cond, Newenv, Handle. *) BeginLoop{loop = pushContainer(loop, leafFn), arguments=arguments} | pushContainer(l as Loop _, _) = l (* Within a BeginLoop only the non-Loop leaves return values. Loop entries go back to the BeginLoop so these are unchanged. *) | pushContainer(Handle{exp, handler, exPacketAddr}, leafFn) = Handle{exp=pushContainer(exp, leafFn), handler=pushContainer(handler, leafFn), exPacketAddr=exPacketAddr} | pushContainer(tuple, leafFn) = leafFn tuple (* Anything else. *) val () = reprocess := true in case useList of [offset] => (* We only want a single field. Push down an Indirect. *) let (* However the context still requires a tuple. We need to reconstruct one with unused fields set to zero. They will be filtered out later by the simplifier pass. *) val field = optGeneral context (pushContainer(code, fn t => mkInd(offset, t))) fun mkFields n = if n = offset then field else CodeZero in Tuple{ fields = List.tabulate(offset+1, mkFields), isVariant = false } end | _ => let (* We require a container. *) val containerAddr = makeAddr() val width = List.length useList val loadContainer = Extract(LoadLocal containerAddr) fun setContainer tuple = (* At the leaf set the container. *) SetContainer{container = loadContainer, tuple = tuple, filter = fieldsToFilter useList } val setCode = optGeneral context (pushContainer(code, setContainer)) val makeContainer = Container{addr=containerAddr, use=[], size=width, setter=setCode} (* The context requires a tuple of the original width. We need to add dummy fields where necessary. 
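   Worked example (values invented): with maxField = 2 and useList = [0, 2]
   the container has size 2 and the rebuilt tuple is (container field 0,
   CodeZero, container field 1); the CodeZero filler is dropped later because
   that field is never used.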
*) val container = if width = maxField+1 then mkTupleFromContainer(containerAddr, width) else let fun mkField(n, m, hd::tl) = if n = hd then mkIndContainer(m, loadContainer) :: mkField(n+1, m+1, tl) else CodeZero :: mkField(n+1, m, hd::tl) | mkField _ = [] in Tuple{fields = mkField(0, 0, useList), isVariant=false} end in mkEnv([makeContainer], container) end end (* TODO: convert "(if a then b else c) (args)" into if a then b(args) else c(args). This would allow for possible inlining and also passing information about call patterns. *) (* Once all the inlining is done we look for functions that can be compiled immediately. These are either functions with no free variables or functions where every use is a call, as opposed to being passed or returned as a closure. Functions that have free variables but are called can be lambda-lifted where the free variables are turned into extra parameters. The advantage compared with using a static-link or a closure on the stack is that they can be fully tail-recursive. With a static-link or stack closure the free variables have to remain on the stack until the function returns. *) fun lambdaLiftAndConstantFunction(code, debugSwitches, numLocals) = let val needReprocess = ref false (* At the moment this just code-generates immediately any lambdas without free variables. The idea is that we will get a constant which can then be inserted directly in references to the function. In general this takes a list of mutually recursive functions which can be code-generated immediately if all the free variables are other functions in the list. The simplifier has separated mutually recursive bindings into strongly connected components so we can consider the list as a single entity. *) fun processLambdas lambdaList = let (* First process the bodies of the functions. *) val needed = ! needReprocess val _ = needReprocess := false; val transLambdas = map (fn {lambda={body, isInline, name, closure, argTypes, resultType, localCount, recUse}, use, addr} => {lambda={body=mapChecks body, isInline=isInline, name=name, closure=closure, argTypes=argTypes, resultType=resultType, localCount=localCount, recUse=recUse}, use=use, addr=addr}) lambdaList val theseTransformed = ! needReprocess val _ = if needed then needReprocess := true else () fun hasFreeVariables{lambda={closure, ...}, ...} = let fun notInLambdas(LoadLocal lAddr) = (* A local is allowed if it only refers to another lambda. *) not (List.exists (fn {addr, ...} => addr = lAddr) lambdaList) | notInLambdas _ = true (* Anything else is not allowed. *) in List.exists notInLambdas closure end in if theseTransformed orelse List.exists (fn {lambda={isInline=Inline, ...}, ...} => true | _ => false) lambdaList orelse List.exists hasFreeVariables lambdaList (* If we have transformed any of the bodies we need to reprocess so defer any code-generation. Don't CG it if it is inline, or perhaps if it is inline and exported. Don't CG it if it has free variables. We still need to examine the bodies of the functions. *) then (transLambdas, []) else let (* Construct code to declare the functions and extract the values. 
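   Sketch of the shape (hypothetical closed, mutually recursive f and g):
   they are bound with RecDecs, tupled, compiled at once by
   BACKEND.codeGenerate, and the resulting constant tuple is taken apart so
   that each original address is re-bound to an indirection into the constant.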
*) val tupleFields = map (fn {addr, ...} => Extract(LoadLocal addr)) transLambdas val decsAndTuple = Newenv([RecDecs transLambdas], mkTuple tupleFields) val maxLocals = List.foldl(fn ({addr, ...}, n) => Int.max(addr, n)) 0 transLambdas val (code, props) = BACKEND.codeGenerate(decsAndTuple, maxLocals + 1, debugSwitches) val resultConstnt = Constnt(code(), props) fun getResults([], _) = [] | getResults({addr, use, ...} :: tail, n) = Declar {value=mkInd(n, resultConstnt), addr=addr, use=use} :: getResults(tail, n+1) val () = needReprocess := true in ([], getResults(transLambdas, 0)) end end and runChecks (Lambda (lambda as { isInline=NonInline, closure=[], ... })) = ( (* Bare lambda. *) case processLambdas[{lambda=lambda, use = [], addr = 0}] of ([{lambda=unCGed, ...}], []) => SOME(Lambda unCGed) | ([], [Declar{value, ...}]) => SOME value | _ => raise InternalError "processLambdas" ) | runChecks (Newenv(bindings, exp)) = let (* We have a block of bindings. Are any of them functions that are only ever called? *) fun checkBindings(Declar{value=Lambda lambda, addr, use}, tail) = ( (* Process this lambda and extract the result. *) case processLambdas[{lambda=lambda, use = use, addr = addr}] of ([{lambda=unCGed, use, addr}], []) => Declar{value=Lambda unCGed, use=use, addr=addr} :: tail | ([], cgedDec) => cgedDec @ tail | _ => raise InternalError "checkBindings" ) | checkBindings(Declar{value, addr, use}, tail) = Declar{value=mapChecks value, addr=addr, use=use} :: tail | checkBindings(RecDecs l, tail) = let val (notConsts, asConsts) = processLambdas l in asConsts @ (if null notConsts then [] else [RecDecs notConsts]) @ tail end | checkBindings(NullBinding exp, tail) = NullBinding(mapChecks exp) :: tail | checkBindings(Container{addr, use, size, setter}, tail) = Container{addr=addr, use=use, size=size, setter=mapChecks setter} :: tail in SOME(Newenv((List.foldr checkBindings [] bindings), mapChecks exp)) end | runChecks _ = NONE and mapChecks c = mapCodetree runChecks c in (mapCodetree runChecks code, numLocals, !needReprocess) end (* Main optimiser and simplifier loop. *) fun codetreeOptimiser(code, debugSwitches, numLocals) = let fun topLevel _ = raise InternalError "top level reached in optimiser" val maxInlineSize = DEBUG.getParameter DEBUG.maxInlineSizeTag debugSwitches fun processTree (code, nLocals, optAgain) = let (* First run the simplifier. Among other things this does inline expansion and if it does any we at least need to run cleanProc on the code so it will have set simpAgain. *) val (simpCode, simpCount, simpAgain) = - SIMPLIFIER.simplifier{code=code, numLocals=nLocals, maxInlineSize=maxInlineSize} + SIMPLIFIER.simplifier{code=code, numLocals=nLocals, maxInlineSize=FixedInt.toInt maxInlineSize} in if optAgain orelse simpAgain then let (* Identify usage information and remove redundant code. *) val printCodeTree = DEBUG.getParameter DEBUG.codetreeTag debugSwitches and compilerOut = PRETTY.getCompilerOutput debugSwitches val simpCode = SIMPLIFIER.specialToGeneral simpCode val () = if printCodeTree then compilerOut(PRETTY.PrettyString "Output of simplifier") else () val () = if printCodeTree then compilerOut (BASECODETREE.pretty simpCode) else () val preOptCode = REMOVE_REDUNDANT.cleanProc(simpCode, [UseExport], topLevel, simpCount) (* Print the code with the use information before it goes into the optimiser. 
*) val () = if printCodeTree then compilerOut(PRETTY.PrettyString "Output of cleaner") else () val () = if printCodeTree then compilerOut (BASECODETREE.pretty preOptCode) else () val reprocess = ref false (* May be set in the optimiser *) (* Allocate any new addresses after the existing ones. *) val addressAllocator = ref simpCount fun makeAddr() = (! addressAllocator) before addressAllocator := ! addressAllocator + 1 val optContext = { makeAddr = makeAddr, reprocess = reprocess, maxInlineSize = maxInlineSize } (* Optimise the code, rewriting it as necessary. *) val optCode = mapCodetree (optimise(optContext, [UseExport])) preOptCode val (llCode, llCount, llAgain) = (* If we have optimised it or the simplifier has run something that it wants to run again we must rerun these before we try to generate any code. *) if ! reprocess (* Re-optimise *) orelse simpAgain (* The simplifier wants to run again on this. *) then (optCode, ! addressAllocator, ! reprocess) else (* We didn't detect any inlineable functions. Check for lambda-lifting. *) lambdaLiftAndConstantFunction(optCode, debugSwitches, ! addressAllocator) (* Print the code after the optimiser. *) val () = if printCodeTree then compilerOut(PRETTY.PrettyString "Output of optimiser") else () val () = if printCodeTree then compilerOut (BASECODETREE.pretty llCode) else () in (* Rerun the simplifier at least. *) processTree(llCode, llCount, llAgain) end else (simpCode, simpCount) (* We're done *) end val (postOptCode, postOptCount) = processTree(code, numLocals, true (* Once at least *)) val (rGeneral, rDecs, rSpec) = postOptCode in { numLocals = postOptCount, general = rGeneral, bindings = rDecs, special = rSpec } end structure Sharing = struct type codetree = codetree and envSpecial = envSpecial and codeBinding = codeBinding end end; diff --git a/mlsource/MLCompiler/CodeTree/CODETREE_SIMPLIFIER.sml b/mlsource/MLCompiler/CodeTree/CODETREE_SIMPLIFIER.sml index 4bc9c3b7..3fb254fe 100644 --- a/mlsource/MLCompiler/CodeTree/CODETREE_SIMPLIFIER.sml +++ b/mlsource/MLCompiler/CodeTree/CODETREE_SIMPLIFIER.sml @@ -1,1747 +1,1747 @@ (* Copyright (c) 2013, 2016-17, 2020 David C.J. Matthews This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License version 2.1 as published by the Free Software Foundation. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA *) (* This is a cut-down version of the optimiser which simplifies the code but does not apply any heuristics. It follows chained bindings, in particular through tuples, folds constant expressions involving built-in functions, and expands inline functions that have previously been marked as inlineable. It does not detect small functions that can be inlined nor does it code-generate functions without free variables. 
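   For example (illustrative): given val p = (1, 2) followed by #1 p, the
   chained binding is followed and the selection folds to the constant 1;
   likewise a call to a function previously marked inline is expanded here,
   although the simplifier itself never decides to mark a function inline.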
*) functor CODETREE_SIMPLIFIER( structure BASECODETREE: BaseCodeTreeSig structure CODETREE_FUNCTIONS: CodetreeFunctionsSig structure REMOVE_REDUNDANT: sig type codetree type loadForm type codeUse val cleanProc : (codetree * codeUse list * (int -> loadForm) * int) -> codetree structure Sharing: sig type codetree = codetree and loadForm = loadForm and codeUse = codeUse end end structure DEBUG: DEBUGSIG sharing BASECODETREE.Sharing = CODETREE_FUNCTIONS.Sharing = REMOVE_REDUNDANT.Sharing ) : sig type codetree and codeBinding and envSpecial val simplifier: { code: codetree, numLocals: int, maxInlineSize: int } -> (codetree * codeBinding list * envSpecial) * int * bool val specialToGeneral: codetree * codeBinding list * envSpecial -> codetree structure Sharing: sig type codetree = codetree and codeBinding = codeBinding and envSpecial = envSpecial end end = struct open BASECODETREE open Address open CODETREE_FUNCTIONS open BuiltIns exception InternalError = Misc.InternalError exception RaisedException (* The bindings are held internally as a reversed list. This is really only a check that the reversed and forward lists aren't confused. *) datatype revlist = RevList of codeBinding list type simpContext = { lookupAddr: loadForm -> envGeneral * envSpecial, enterAddr: int * (envGeneral * envSpecial) -> unit, nextAddress: unit -> int, reprocess: bool ref, maxInlineSize: int } fun envGeneralToCodetree(EnvGenLoad ext) = Extract ext | envGeneralToCodetree(EnvGenConst w) = Constnt w fun mkDec (laddr, res) = Declar{value = res, addr = laddr, use=[]} fun mkEnv([], exp) = exp | mkEnv(decs, exp as Extract(LoadLocal loadAddr)) = ( (* A common case is where we have a binding as the last item and then a load of that binding. Reduce this so other optimisations are possible. This is still something of a special case that could/should be generalised. *) case List.last decs of Declar{addr=decAddr, value, ... } => if loadAddr = decAddr then mkEnv(List.take(decs, List.length decs - 1), value) else Newenv(decs, exp) | _ => Newenv(decs, exp) ) | mkEnv(decs, exp) = Newenv(decs, exp) fun isConstnt(Constnt _) = true | isConstnt _ = false (* Wrap up the general, bindings and special value as a codetree node. The special entry is discarded except for Constnt entries which are converted to ConstntWithInline. That allows any inlineable code to be carried forward to later passes. *) fun specialToGeneral(g, RevList(b as _ :: _), s) = mkEnv(List.rev b, specialToGeneral(g, RevList [], s)) | specialToGeneral(Constnt(w, p), RevList [], s) = Constnt(w, setInline s p) | specialToGeneral(g, RevList [], _) = g (* Convert a constant to a fixed value. Used in some constant folding. *) val toFix: machineWord -> FixedInt.int = FixedInt.fromInt o Word.toIntX o toShort local val ffiSizeFloat: unit -> word = RunCall.rtsCallFast1 "PolySizeFloat" and ffiSizeDouble: unit -> word = RunCall.rtsCallFast1 "PolySizeDouble" in (* If we have a constant index value we convert that into a byte offset. We need to know the size of the item on this platform. We have to make this check when we actually compile the code because the interpreted version will generally be run on a platform different from the one the pre-built compiler was compiled on. The ML word length will be the same because we have separate pre-built compilers for 32 and 64-bit. 
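   As a worked example (platform values assumed): on a 64-bit platform
   RunCall.bytesPerWord is 0w8, so an ML word index of 3 scales to a byte
   offset of 24, while LoadStoreC16 scales the same index by 0w2 to 6.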
*) fun getMultiplier (LoadStoreMLWord _) = RunCall.bytesPerWord | getMultiplier (LoadStoreMLByte _) = 0w1 | getMultiplier LoadStoreC8 = 0w1 | getMultiplier LoadStoreC16 = 0w2 | getMultiplier LoadStoreC32 = 0w4 | getMultiplier LoadStoreC64 = 0w8 | getMultiplier LoadStoreCFloat = ffiSizeFloat() | getMultiplier LoadStoreCDouble = ffiSizeDouble() | getMultiplier LoadStoreUntaggedUnsigned = RunCall.bytesPerWord end fun simplify(c, s) = mapCodetree (simpGeneral s) c (* Process the codetree to return a codetree node. This is used when we don't want the special case. *) and simpGeneral { lookupAddr, ...} (Extract ext) = let val (gen, spec) = lookupAddr ext in SOME(specialToGeneral(envGeneralToCodetree gen, RevList [], spec)) end | simpGeneral context (Newenv envArgs) = SOME(specialToGeneral(simpNewenv(envArgs, context, RevList []))) | simpGeneral context (Lambda lambda) = SOME(Lambda(#1(simpLambda(lambda, context, NONE, NONE)))) | simpGeneral context (Eval {function, argList, resultType}) = SOME(specialToGeneral(simpFunctionCall(function, argList, resultType, context, RevList[]))) (* BuiltIn0 functions can't be processed specially. *) | simpGeneral context (Unary{oper, arg1}) = SOME(specialToGeneral(simpUnary(oper, arg1, context, RevList []))) | simpGeneral context (Binary{oper, arg1, arg2}) = SOME(specialToGeneral(simpBinary(oper, arg1, arg2, context, RevList []))) | simpGeneral context (Arbitrary{oper=ArbCompare test, shortCond, arg1, arg2, longCall}) = SOME(specialToGeneral(simpArbitraryCompare(test, shortCond, arg1, arg2, longCall, context, RevList []))) | simpGeneral context (Arbitrary{oper=ArbArith arith, shortCond, arg1, arg2, longCall}) = SOME(specialToGeneral(simpArbitraryArith(arith, shortCond, arg1, arg2, longCall, context, RevList []))) | simpGeneral context (AllocateWordMemory {numWords, flags, initial}) = SOME(specialToGeneral(simpAllocateWordMemory(numWords, flags, initial, context, RevList []))) | simpGeneral context (Cond(condTest, condThen, condElse)) = SOME(specialToGeneral(simpIfThenElse(condTest, condThen, condElse, context, RevList []))) | simpGeneral context (Tuple { fields, isVariant }) = SOME(specialToGeneral(simpTuple(fields, isVariant, context, RevList []))) | simpGeneral context (Indirect{ base, offset, indKind }) = SOME(specialToGeneral(simpFieldSelect(base, offset, indKind, context, RevList []))) | simpGeneral context (SetContainer{container, tuple, filter}) = let val optCont = simplify(container, context) val (cGen, cDecs, cSpec) = simpSpecial(tuple, context, RevList []) in case cSpec of (* If the tuple is a local binding it is simpler to pick it up from the "special" entry. *) EnvSpecTuple(size, recEnv) => let val fields = List.tabulate(size, envGeneralToCodetree o #1 o recEnv) in SOME(simpPostSetContainer(optCont, Tuple{isVariant=false, fields=fields}, cDecs, filter)) end | _ => SOME(simpPostSetContainer(optCont, cGen, cDecs, filter)) end | simpGeneral (context as { enterAddr, nextAddress, reprocess, ...}) (BeginLoop{loop, arguments, ...}) = let val didReprocess = ! reprocess (* To see if we really need the loop, first try simply binding the arguments and processing it. It's often the case that if one or more arguments is a constant the looping case will be eliminated. 
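   Hypothetical instance: if the loop's continuation test folds to false once
   a constant argument has been bound, foldLoop finds no Loop node in
   withoutBeginLoop and the whole BeginLoop is dropped.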
*) val withoutBeginLoop = simplify(mkEnv(List.map (Declar o #1) arguments, loop), context) fun foldLoop f n (Loop l) = f(l, n) | foldLoop f n (Newenv(_, exp)) = foldLoop f n exp | foldLoop f n (Cond(_, t, e)) = foldLoop f (foldLoop f n t) e | foldLoop f n (Handle {handler, ...}) = foldLoop f n handler | foldLoop f n (SetContainer{tuple, ...}) = foldLoop f n tuple | foldLoop _ n _ = n (* Check if the Loop instruction is there. This assumes that these are the only tail-recursive cases. *) val hasLoop = foldLoop (fn _ => true) false in if not (hasLoop withoutBeginLoop) then SOME withoutBeginLoop else let (* Reset "reprocess". It may have been set while simplifying withoutBeginLoop, which is not the code we're going to return. *) val () = reprocess := didReprocess (* We need the BeginLoop. Create new addresses for the arguments. *) fun declArg({addr, value, use, ...}, typ) = let val newAddr = nextAddress() in enterAddr(addr, (EnvGenLoad(LoadLocal newAddr), EnvSpecNone)); ({addr = newAddr, value = simplify(value, context), use = use }, typ) end (* Now look to see if the (remaining) loops have any arguments that do not change. Do this after processing because we could be eliminating other loops that may change the arguments. *) val declArgs = map declArg arguments val beginBody = simplify(loop, context) local fun argsMatch((Extract (LoadLocal argNo), _), ({addr, ...}, _)) = argNo = addr | argsMatch _ = false fun checkLoopArgs(loopArgs, checks) = let fun map3(loopA :: loopArgs, decA :: decArgs, checkA :: checkArgs) = (argsMatch(loopA, decA) andalso checkA) :: map3(loopArgs, decArgs, checkArgs) | map3 _ = [] in map3(loopArgs, declArgs, checks) end in val checkList = foldLoop checkLoopArgs (map (fn _ => true) arguments) beginBody end in if List.exists (fn l => l) checkList then let (* Turn the original arguments into bindings. *) local fun argLists(true, (arg, _), (tArgs, fArgs)) = (Declar arg :: tArgs, fArgs) | argLists(false, arg, (tArgs, fArgs)) = (tArgs, arg :: fArgs) in val (unchangedArgs, filteredDeclArgs) = ListPair.foldrEq argLists ([], []) (checkList, declArgs) end fun changeLoops (Loop loopArgs) = let val newArgs = ListPair.foldrEq(fn (false, arg, l) => arg :: l | (true, _, l) => l) [] (checkList, loopArgs) in Loop newArgs end | changeLoops(Newenv(decs, exp)) = Newenv(decs, changeLoops exp) | changeLoops(Cond(i, t, e)) = Cond(i, changeLoops t, changeLoops e) | changeLoops(Handle{handler, exp, exPacketAddr}) = Handle{handler=changeLoops handler, exp=exp, exPacketAddr=exPacketAddr} | changeLoops(SetContainer{tuple, container, filter}) = SetContainer{tuple=changeLoops tuple, container=container, filter=filter} | changeLoops code = code val beginBody = simplify(changeLoops loop, context) (* Reprocess because we've lost any special part from the arguments that haven't changed. *) val () = reprocess := true in SOME(mkEnv(unchangedArgs, BeginLoop {loop=beginBody, arguments=filteredDeclArgs})) end else SOME(BeginLoop {loop=beginBody, arguments=declArgs}) end end | simpGeneral context (TagTest{test, tag, maxTag}) = ( case simplify(test, context) of Constnt(testResult, _) => if isShort testResult andalso toShort testResult = tag then SOME CodeTrue else SOME CodeFalse | sTest => SOME(TagTest{test=sTest, tag=tag, maxTag=maxTag}) ) | simpGeneral context (LoadOperation{kind, address}) = let (* Try to move constants out of the index.
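   Schematically (property lists omitted), an access such as Vector.sub(v, 3) arrives as
       {base = v, index = SOME (Constnt 3), offset = 0w0}
   and, with multiplier = RunCall.bytesPerWord, simpAddress below rewrites it to
       {base = v, index = NONE, offset = 0w3 * RunCall.bytesPerWord}
   so the code-generator sees a fixed displacement.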
*) val (genAddress, RevList decAddress) = simpAddress(address, getMultiplier kind, context) (* If the base address and index are constant and this is an immutable load we can do this at compile time. *) val result = case (genAddress, kind) of ({base=Constnt(baseAddr, _), index=NONE, offset}, LoadStoreMLWord _) => if isShort baseAddr then LoadOperation{kind=kind, address=genAddress} else let (* Ignore the "isImmutable" flag and look at the immutable status of the memory. Check that this is a word object and that the offset is within range. The code for Vector.sub, for example, raises an exception if the index is out of range but still generates the (unreachable) indexing code. *) val addr = toAddress baseAddr val wordOffset = offset div RunCall.bytesPerWord in if isMutable addr orelse not(isWords addr) orelse wordOffset >= length addr then LoadOperation{kind=kind, address=genAddress} else Constnt(toMachineWord(loadWord(addr, wordOffset)), []) end | ({base=Constnt(baseAddr, _), index=NONE, offset}, LoadStoreMLByte _) => if isShort baseAddr then LoadOperation{kind=kind, address=genAddress} else let val addr = toAddress baseAddr val wordOffset = offset div RunCall.bytesPerWord in if isMutable addr orelse not(isBytes addr) orelse wordOffset >= length addr then LoadOperation{kind=kind, address=genAddress} else Constnt(toMachineWord(loadByte(addr, offset)), []) end | ({base=Constnt(baseAddr, _), index=NONE, offset}, LoadStoreUntaggedUnsigned) => if isShort baseAddr then LoadOperation{kind=kind, address=genAddress} else let val addr = toAddress baseAddr (* We don't currently have loadWordUntagged in Address but it's only ever used to load the string length word so we can use that. *) in if isMutable addr orelse not(isBytes addr) orelse offset <> 0w0 then LoadOperation{kind=kind, address=genAddress} else Constnt(toMachineWord(String.size(RunCall.unsafeCast addr)), []) end | _ => LoadOperation{kind=kind, address=genAddress} in SOME(mkEnv(List.rev decAddress, result)) end | simpGeneral context (StoreOperation{kind, address, value}) = let val (genAddress, decAddress) = simpAddress(address, getMultiplier kind, context) val (genValue, RevList decValue, _) = simpSpecial(value, context, decAddress) in SOME(mkEnv(List.rev decValue, StoreOperation{kind=kind, address=genAddress, value=genValue})) end | simpGeneral (context as {reprocess, ...}) (BlockOperation{kind, sourceLeft, destRight, length}) = let val multiplier = case kind of BlockOpMove{isByteMove=false} => RunCall.bytesPerWord | BlockOpMove{isByteMove=true} => 0w1 | BlockOpEqualByte => 0w1 | BlockOpCompareByte => 0w1 val (genSrcAddress, RevList decSrcAddress) = simpAddress(sourceLeft, multiplier, context) val (genDstAddress, RevList decDstAddress) = simpAddress(destRight, multiplier, context) val (genLength, RevList decLength, _) = simpSpecial(length, context, RevList []) (* If we have a short length move we're better off doing it as a sequence of loads and stores. This is particularly useful with string concatenation. Small here means three or fewer. Four and eight byte moves are handled as single instructions in the code-generator provided the alignment is correct.
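   As a sketch of the unrolled form (using Word8Array as an analogy for the byte
   segments involved), a length-2 byte move becomes two load/store pairs:
       fun shortMove (src, dst) =
           ( Word8Array.update(dst, 0, Word8Array.sub(src, 0));
             Word8Array.update(dst, 1, Word8Array.sub(src, 1)) )
   instead of a general block-move operation.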
*) val shortLength = case genLength of Constnt(lenConst, _) => if isShort lenConst then let val l = toShort lenConst in if l <= 0w3 then SOME l else NONE end else NONE | _ => NONE val combinedDecs = List.rev decSrcAddress @ List.rev decDstAddress @ List.rev decLength val operation = case (shortLength, kind) of (SOME length, BlockOpMove{isByteMove}) => let val _ = reprocess := true (* Frequently the source will be a constant. *) val {base=baseSrc, index=indexSrc, offset=offsetSrc} = genSrcAddress and {base=baseDst, index=indexDst, offset=offsetDst} = genDstAddress (* We don't know if the source is immutable but the destination definitely isn't *) val moveKind = if isByteMove then LoadStoreMLByte{isImmutable=false} else LoadStoreMLWord{isImmutable=false} fun makeMoves offset = if offset = length then [] else NullBinding( StoreOperation{kind=moveKind, address={base=baseDst, index=indexDst, offset=offsetDst+offset*multiplier}, value=LoadOperation{kind=moveKind, address={base=baseSrc, index=indexSrc, offset=offsetSrc+offset*multiplier}}}) :: makeMoves(offset+0w1) in mkEnv(combinedDecs @ makeMoves 0w0, CodeZero (* unit result *)) end | (SOME length, BlockOpEqualByte) => (* Comparing with the null string and up to 3 characters. *) let val {base=baseSrc, index=indexSrc, offset=offsetSrc} = genSrcAddress and {base=baseDst, index=indexDst, offset=offsetDst} = genDstAddress val moveKind = LoadStoreMLByte{isImmutable=false} (* Build andalso tree to check each byte. For the null string this simply returns "true". *) fun makeComparison offset = if offset = length then CodeTrue else Cond( Binary{oper=WordComparison{test=TestEqual, isSigned=false}, arg1=LoadOperation{kind=moveKind, address={base=baseSrc, index=indexSrc, offset=offsetSrc+offset*multiplier}}, arg2=LoadOperation{kind=moveKind, address={base=baseDst, index=indexDst, offset=offsetDst+offset*multiplier}}}, makeComparison(offset+0w1), CodeFalse) in mkEnv(combinedDecs, makeComparison 0w0) end | _ => mkEnv(combinedDecs, BlockOperation{kind=kind, sourceLeft=genSrcAddress, destRight=genDstAddress, length=genLength}) in SOME operation end | simpGeneral (context as {enterAddr, nextAddress, ...}) (Handle{exp, handler, exPacketAddr}) = let (* We need to make a new binding for the exception packet. *) val expBody = simplify(exp, context) val newAddr = nextAddress() val () = enterAddr(exPacketAddr, (EnvGenLoad(LoadLocal newAddr), EnvSpecNone)) val handleBody = simplify(handler, context) in SOME(Handle{exp=expBody, handler=handleBody, exPacketAddr=newAddr}) end | simpGeneral _ _ = NONE (* Where we have an Indirect or Eval we want the argument as either a tuple or an inline function respectively if that's possible. Getting that also involves various other cases as well. Because a binding may later be used in such a context we treat any binding in that way as well. 
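   For example, in
       let val p = (x, y) in #1 p end
   the binding of p carries an EnvSpecTuple special entry, so #1 p is replaced directly
   by x; if p is then never used as a whole, the tuple is never constructed at all.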
*) and simpSpecial (Extract ext, { lookupAddr, ...}, tailDecs) = let val (gen, spec) = lookupAddr ext in (envGeneralToCodetree gen, tailDecs, spec) end | simpSpecial (Newenv envArgs, context, tailDecs) = simpNewenv(envArgs, context, tailDecs) | simpSpecial (Lambda lambda, context, tailDecs) = let val (gen, spec) = simpLambda(lambda, context, NONE, NONE) in (Lambda gen, tailDecs, spec) end | simpSpecial (Eval {function, argList, resultType}, context, tailDecs) = simpFunctionCall(function, argList, resultType, context, tailDecs) | simpSpecial (Unary{oper, arg1}, context, tailDecs) = simpUnary(oper, arg1, context, tailDecs) | simpSpecial (Binary{oper, arg1, arg2}, context, tailDecs) = simpBinary(oper, arg1, arg2, context, tailDecs) | simpSpecial (Arbitrary{oper=ArbCompare test, shortCond, arg1, arg2, longCall}, context, tailDecs) = simpArbitraryCompare(test, shortCond, arg1, arg2, longCall, context, tailDecs) | simpSpecial (Arbitrary{oper=ArbArith arith, shortCond, arg1, arg2, longCall}, context, tailDecs) = simpArbitraryArith(arith, shortCond, arg1, arg2, longCall, context, tailDecs) | simpSpecial (AllocateWordMemory{numWords, flags, initial}, context, tailDecs) = simpAllocateWordMemory(numWords, flags, initial, context, tailDecs) | simpSpecial (Cond(condTest, condThen, condElse), context, tailDecs) = simpIfThenElse(condTest, condThen, condElse, context, tailDecs) | simpSpecial (Tuple { fields, isVariant }, context, tailDecs) = simpTuple(fields, isVariant, context, tailDecs) | simpSpecial (Indirect{ base, offset, indKind }, context, tailDecs) = simpFieldSelect(base, offset, indKind, context, tailDecs) | simpSpecial (c: codetree, s: simpContext, tailDecs): codetree * revlist * envSpecial = let (* Anything else - copy it and then split it into the fields. *) fun split(Newenv(l, e), RevList tailDecs) = (* Pull off bindings. *) split (e, RevList(List.rev l @ tailDecs)) | split(Constnt(m, p), tailDecs) = (Constnt(m, p), tailDecs, findInline p) | split(c, tailDecs) = (c, tailDecs, EnvSpecNone) in split(simplify(c, s), tailDecs) end (* Process a Newenv. We need to add the bindings to the context. *) and simpNewenv((envDecs: codeBinding list, envExp), context as { enterAddr, nextAddress, reprocess, ...}, tailDecs): codetree * revlist * envSpecial = let fun copyDecs ([], decs) = simpSpecial(envExp, context, decs) (* End of the list - process the result expression. *) | copyDecs ((Declar{addr, value, ...} :: vs), decs) = ( case simpSpecial(value, context, decs) of (* If this raises an exception stop here. *) vBinding as (Raise _, _, _) => vBinding | vBinding => let (* Add the declaration to the table. *) val (optV, dec) = makeNewDecl(vBinding, context) val () = enterAddr(addr, optV) in copyDecs(vs, dec) end ) | copyDecs(NullBinding v :: vs, decs) = (* Not a binding - process this and the rest.*) ( case simpSpecial(v, context, decs) of (* If this raises an exception stop here. *) vBinding as (Raise _, _, _) => vBinding | (cGen, RevList cDecs, _) => copyDecs(vs, RevList(NullBinding cGen :: cDecs)) ) | copyDecs(RecDecs mutuals :: vs, RevList decs) = (* Mutually recursive declarations. Any of the declarations may refer to any of the others. They should all be lambdas. The front end generates functions with more than one argument (either curried or tupled) as pairs of mutually recursive functions. The main function body takes its arguments on the stack (or in registers) and the auxiliary inline function, possibly nested, takes the tupled or curried arguments and calls it. 
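   Schematically (the names here are illustrative only), a curried definition
       fun f x y = x + y
   is generated as a pair along the lines of
       val rec fMain = fn (x, y) => x + y            (* main function: all args at once *)
       and f = fn x => fn y => fMain (x, y)          (* small inline wrapper *)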
If the main function is recursive it will first call the inline function which is why the pair are mutually recursive. As far as possible we want to use the main function since that uses the least memory. Specifically, if the function recurses we want the recursive call to pass all the arguments if it can. *) let (* Reorder the functions so the explicitly-inlined ones come first. Their code can then be inserted into the main functions. *) local val (inlines, nonInlines) = List.partition ( fn {lambda = { isInline=Inline, ...}, ... } => true | _ => false) mutuals in val orderedDecs = inlines @ nonInlines end (* Go down the functions creating new addresses for them and entering them in the table. *) val addresses = map (fn {addr, ... } => let val decAddr = nextAddress() in enterAddr (addr, (EnvGenLoad(LoadLocal decAddr), EnvSpecNone)); decAddr end) orderedDecs fun processFunction({ lambda, addr, ... }, newAddr) = let val (gen, spec) = simpLambda(lambda, context, SOME addr, SOME newAddr) (* Update the entry in the table to include any inlineable function. *) val () = enterAddr (addr, (EnvGenLoad (LoadLocal newAddr), spec)) in {addr=newAddr, lambda=gen, use=[]} end val rlist = ListPair.map processFunction (orderedDecs, addresses) in (* and put these declarations onto the list. *) copyDecs(vs, RevList(List.rev(partitionMutualBindings(RecDecs rlist)) @ decs)) end | copyDecs (Container{addr, size, setter, ...} :: vs, RevList decs) = let (* Enter the new address immediately - it's needed in the setter. *) val decAddr = nextAddress() val () = enterAddr (addr, (EnvGenLoad(LoadLocal decAddr), EnvSpecNone)) val (setGen, RevList setDecs, _) = simpSpecial(setter, context, RevList []) in (* If we have inline expanded a function that sets the container we're better off eliminating the container completely. *) case setGen of SetContainer { tuple, filter, container } => let (* Check that the container we're setting is the address we've made for it. *) val _ = (case container of Extract(LoadLocal a) => a = decAddr | _ => false) orelse raise InternalError "copyDecs: Container/SetContainer" val newDecAddr = nextAddress() val () = enterAddr (addr, (EnvGenLoad(LoadLocal newDecAddr), EnvSpecNone)) val tupleAddr = nextAddress() val tupleDec = Declar{addr=tupleAddr, use=[], value=tuple} val tupleLoad = mkLoadLocal tupleAddr val resultTuple = BoolVector.foldri(fn (i, true, l) => mkInd(i, tupleLoad) :: l | (_, false, l) => l) [] filter val _ = List.length resultTuple = size orelse raise InternalError "copyDecs: Container/SetContainer size" val containerDec = Declar{addr=newDecAddr, use=[], value=mkTuple resultTuple} (* TODO: We're replacing a container with what is notionally a tuple on the heap. It should be optimised away as a result of a further pass but we currently have indirections from a container for these. On the native platforms that doesn't matter but on 32-in-64 indirecting from the heap and from the stack are different. *) val _ = reprocess := true in copyDecs(vs, RevList(containerDec :: tupleDec :: setDecs @ decs)) end | _ => let (* The setDecs could refer to the container itself if we've optimised this with simpPostSetContainer so we must include them within the setter and not lift them out. *) val dec = Container{addr=decAddr, use=[], size=size, setter=mkEnv(List.rev setDecs, setGen)} in copyDecs(vs, RevList(dec :: decs)) end end in copyDecs(envDecs, tailDecs) end (* Prepares a binding for entry into a look-up table. Returns the entry to put into the table together with any bindings that must be made.
If the general part of the optVal is a constant we can just put the constant in the table. If it is a load (Extract) it is just renaming an existing entry so we can return it. Otherwise we have to make a new binding and return a load (Extract) entry for it. *) and makeNewDecl((Constnt w, RevList decs, spec), _) = ((EnvGenConst w, spec), RevList decs) (* No need to create a binding for a constant. *) | makeNewDecl((Extract ext, RevList decs, spec), _) = ((EnvGenLoad ext, spec), RevList decs) (* Binding is simply giving a new name to a variable - can ignore this declaration. *) | makeNewDecl((gen, RevList decs, spec), { nextAddress, ...}) = let (* Create a binding for this value. *) val newAddr = nextAddress() in ((EnvGenLoad(LoadLocal newAddr), spec), RevList(mkDec(newAddr, gen) :: decs)) end and simpLambda({body, isInline, name, argTypes, resultType, closure, localCount, ...}, { lookupAddr, reprocess, maxInlineSize, ... }, myOldAddrOpt, myNewAddrOpt) = let (* A new table for the new function. *) val oldAddrTab = Array.array (localCount, NONE) val optClosureList = makeClosure() val isNowRecursive = ref false local fun localOldAddr (LoadLocal addr) = valOf(Array.sub(oldAddrTab, addr)) | localOldAddr (ext as LoadArgument _) = (EnvGenLoad ext, EnvSpecNone) | localOldAddr (ext as LoadRecursive) = (EnvGenLoad ext, EnvSpecNone) | localOldAddr (LoadClosure addr) = let val oldEntry = List.nth(closure, addr) (* If the entry in the closure is our own address this is recursive. *) fun isRecursive(EnvGenLoad(LoadLocal a), SOME b) = if a = b then (isNowRecursive := true; true) else false | isRecursive _ = false in if isRecursive(EnvGenLoad oldEntry, myOldAddrOpt) then (EnvGenLoad LoadRecursive, EnvSpecNone) else let val newEntry = lookupAddr oldEntry val makeClosure = addToClosure optClosureList fun convertResult(genEntry, specEntry) = (* If after looking up the entry we get our new address it's recursive. *) if isRecursive(genEntry, myNewAddrOpt) then (EnvGenLoad LoadRecursive, EnvSpecNone) else let val newGeneral = case genEntry of EnvGenLoad ext => EnvGenLoad(makeClosure ext) | EnvGenConst w => EnvGenConst w (* Have to modify the environment here so that if we look up free variables we add them to the closure. *) fun convertEnv env args = convertResult(env args) val newSpecial = case specEntry of EnvSpecTuple(size, env) => EnvSpecTuple(size, convertEnv env) | EnvSpecInlineFunction(spec, env) => EnvSpecInlineFunction(spec, convertEnv env) | EnvSpecUnary _ => EnvSpecNone (* Don't pass this in *) | EnvSpecBinary _ => EnvSpecNone (* Don't pass this in *) | EnvSpecNone => EnvSpecNone in (newGeneral, newSpecial) end in convertResult newEntry end end and setTab (index, v) = Array.update (oldAddrTab, index, SOME v) in val newAddressAllocator = ref 0 fun mkAddr () = ! newAddressAllocator before newAddressAllocator := ! newAddressAllocator + 1 val newCode = simplify (body, { enterAddr = setTab, lookupAddr = localOldAddr, nextAddress=mkAddr, reprocess = reprocess, maxInlineSize = maxInlineSize }) end val closureAfterOpt = extractClosure optClosureList val localCount = ! newAddressAllocator (* If we have mutually recursive "small" functions we may turn them into recursive functions. We have to remove the "small" status from them to prevent them from being expanded inline anywhere else. The optimiser may turn them back into "small" functions if the recursion is actually tail-recursion. *) val isNowInline = case isInline of Inline => if ! 
isNowRecursive then NonInline else Inline | NonInline => NonInline (* Clean up the function body at this point if it could be inlined. There are examples where failing to do this can blow up. This can be the result of creating both a general and special function inside an inline function. *) val cleanBody = case isNowInline of NonInline => newCode | _ => REMOVE_REDUNDANT.cleanProc(newCode, [UseExport], LoadClosure, localCount) val copiedLambda: lambdaForm = { body = cleanBody, isInline = isNowInline, name = name, closure = closureAfterOpt, argTypes = argTypes, resultType = resultType, localCount = localCount, recUse = [] } (* The optimiser checks the size of a function and decides whether it can be inlined. However if we have expanded some other inlines inside the body it may now be too big. In some cases we can get exponential blow-up. We check here that the body is still small enough before allowing it to be used inline. *) val inlineCode = if isNowInline = Inline andalso - evaluateInlining(cleanBody, List.length argTypes, FixedInt.toInt maxInlineSize) <> TooBig + evaluateInlining(cleanBody, List.length argTypes, maxInlineSize) <> TooBig then EnvSpecInlineFunction(copiedLambda, fn addr => (EnvGenLoad(List.nth(closureAfterOpt, addr)), EnvSpecNone)) else EnvSpecNone in ( copiedLambda, inlineCode ) end and simpFunctionCall(function, argList, resultType, context as { reprocess, maxInlineSize, ...}, tailDecs) = let (* Function call - This may involve inlining the function. *) (* Get the function to be called and see if it is inline or a lambda expression. *) val (genFunct, decsFunct, specFunct) = simpSpecial(function, context, tailDecs) (* We have to make a special check here that we are not passing in the function we are trying to expand. This could result in an infinitely recursive expansion. It is only going to happen in very special circumstances such as a definition of the Y combinator. If we see that we don't attempt to expand inline. It could be embedded in a tuple or the closure of a function as well as passed directly. *) val isRecursiveArg = case function of Extract extOrig => let fun containsFunction(Extract thisArg, v) = (v orelse thisArg = extOrig, FOLD_DESCEND) | containsFunction(Lambda{closure, ...}, v) = (* Only the closure, not the body *) (foldl (fn (c, w) => foldtree containsFunction w (Extract c)) v closure, FOLD_DONT_DESCEND) | containsFunction(Eval _, v) = (v, FOLD_DONT_DESCEND) (* OK if it's called *) | containsFunction(_, v) = (v, FOLD_DESCEND) in List.exists(fn (c, _) => foldtree containsFunction false c) argList end | _ => false in case (specFunct, genFunct, isRecursiveArg) of (EnvSpecInlineFunction({body=lambdaBody, localCount, argTypes, ...}, functEnv), _, false) => let val _ = List.length argTypes = List.length argList orelse raise InternalError "simpFunctionCall: argument mismatch" val () = reprocess := true (* If we expand inline we have to reprocess *) and { nextAddress, reprocess, ...} = context (* Expand a function inline, either one marked explicitly to be inlined or one detected as "small". *) (* Calling inline proc or a lambda expression which is just called. The function is replaced with a block containing declarations of the parameters. We need a new table here because the addresses we use to index it are the addresses which are local to the function. New addresses are created in the range of the surrounding function. 
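   For example (a sketch), expanding
       let fun add1 x = x + 1 in add1 y end     (* add1 small enough to inline *)
   produces code equivalent to
       let val x = y in x + 1 end
   where x is re-bound at a fresh address allocated in the calling function.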
*) val localVec = Array.array(localCount, NONE) local fun processArgs([], bindings) = ([], bindings) | processArgs((arg, _)::args, bindings) = let val (thisArg, newBindings) = makeNewDecl(simpSpecial(arg, context, bindings), context) val (otherArgs, resBindings) = processArgs(args, newBindings) in (thisArg::otherArgs, resBindings) end val (params, bindings) = processArgs(argList, decsFunct) val paramVec = Vector.fromList params in fun getParameter n = Vector.sub(paramVec, n) (* Bindings necessary for the arguments *) val copiedArgs = bindings end local fun localOldAddr(LoadLocal addr) = valOf(Array.sub(localVec, addr)) | localOldAddr(LoadArgument addr) = getParameter addr | localOldAddr(LoadClosure closureEntry) = functEnv closureEntry | localOldAddr LoadRecursive = raise InternalError "localOldAddr: LoadRecursive" fun setTabForInline (index, v) = Array.update (localVec, index, SOME v) val lambdaContext = { lookupAddr=localOldAddr, enterAddr=setTabForInline, nextAddress=nextAddress, reprocess = reprocess, maxInlineSize = maxInlineSize } in val (cGen, cDecs, cSpec) = simpSpecial(lambdaBody,lambdaContext, copiedArgs) end in (cGen, cDecs, cSpec) end | (_, gen as Constnt _, _) => (* Not inlinable - constant function. *) let val copiedArgs = map (fn (arg, argType) => (simplify(arg, context), argType)) argList val evCopiedCode = Eval {function = gen, argList = copiedArgs, resultType=resultType} in (evCopiedCode, decsFunct, EnvSpecNone) end | (_, gen, _) => (* Anything else. *) let val copiedArgs = map (fn (arg, argType) => (simplify(arg, context), argType)) argList val evCopiedCode = Eval {function = gen, argList = copiedArgs, resultType=resultType} in (evCopiedCode, decsFunct, EnvSpecNone) end end (* Special processing for the current builtIn1 operations. *) (* Constant folding for built-ins. These ought to be type-correct i.e. we should have tagged values in some cases and addresses in others. However there may be run-time tests that would ensure type-correctness and we can't be sure that they will always be folded at compile-time. e.g. we may have if isShort c then shortOp c else longOp c If c is a constant then we may try to fold both the shortOp and the longOp and one of these will be type-incorrect although never executed at run-time. *) and simpUnary(oper, arg1, context as { reprocess, ...}, tailDecs) = let val (genArg1, decArg1, specArg1) = simpSpecial(arg1, context, tailDecs) in case (oper, genArg1) of (NotBoolean, Constnt(v, _)) => ( reprocess := true; (if isShort v andalso toShort v = 0w0 then CodeTrue else CodeFalse, decArg1, EnvSpecNone) ) | (NotBoolean, genArg1) => ( (* NotBoolean: This can be the result of using Bool.not but more usually occurs as a result of other code. We don't have TestNotEqual or IsAddress so both of these use NotBoolean with TestEqual and IsTagged. Also we can insert a NotBoolean as a result of a Cond. We try to eliminate not(not a) and to push other NotBooleans down to a point where a boolean is tested. *) case specArg1 of EnvSpecUnary(NotBoolean, originalArg) => ( (* not(not a) - Eliminate. *) reprocess := true; (originalArg, decArg1, EnvSpecNone) ) | _ => (* Otherwise pass this on. It is also extracted in a Cond. 
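   For example, not (not b) reduces to b via this special entry, and carrying the
   NotBoolean into an enclosing conditional lets simpIfThenElse fold patterns such as
       if not b then false else true
   down to plain b.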
*) (Unary{oper=NotBoolean, arg1=genArg1}, decArg1, EnvSpecUnary(NotBoolean, genArg1)) ) | (IsTaggedValue, Constnt(v, _)) => ( reprocess := true; (if isShort v then CodeTrue else CodeFalse, decArg1, EnvSpecNone) ) | (IsTaggedValue, genArg1) => ( (* We use this to test for nil values and if we have constructed a record (or possibly a function) it can't be null. *) case specArg1 of EnvSpecTuple _ => (CodeFalse, decArg1, EnvSpecNone) before reprocess := true | EnvSpecInlineFunction _ => (CodeFalse, decArg1, EnvSpecNone) before reprocess := true | _ => (Unary{oper=oper, arg1=genArg1}, decArg1, EnvSpecNone) ) | (MemoryCellLength, Constnt(v, _)) => ( reprocess := true; (if isShort v then CodeZero else Constnt(toMachineWord(Address.length(toAddress v)), []), decArg1, EnvSpecNone) ) | (MemoryCellFlags, Constnt(v, _)) => ( reprocess := true; (if isShort v then CodeZero else Constnt(toMachineWord(Address.flags(toAddress v)), []), decArg1, EnvSpecNone) ) | (LongWordToTagged, Constnt(v, _)) => ( reprocess := true; (Constnt(toMachineWord(Word.fromLargeWord(RunCall.unsafeCast v)), []), decArg1, EnvSpecNone) ) | (LongWordToTagged, genArg1) => ( (* If we apply LongWordToTagged to an argument we have created with UnsignedToLongWord we can return the original argument. *) case specArg1 of EnvSpecUnary(UnsignedToLongWord, originalArg) => ( reprocess := true; (originalArg, decArg1, EnvSpecNone) ) | _ => (Unary{oper=LongWordToTagged, arg1=genArg1}, decArg1, EnvSpecNone) ) | (SignedToLongWord, Constnt(v, _)) => ( reprocess := true; (Constnt(toMachineWord(Word.toLargeWordX(RunCall.unsafeCast v)), []), decArg1, EnvSpecNone) ) | (UnsignedToLongWord, Constnt(v, _)) => ( reprocess := true; (Constnt(toMachineWord(Word.toLargeWord(RunCall.unsafeCast v)), []), decArg1, EnvSpecNone) ) | (UnsignedToLongWord, genArg1) => (* Add the operation as the special entry. It can then be recognised by LongWordToTagged. *) (Unary{oper=oper, arg1=genArg1}, decArg1, EnvSpecUnary(UnsignedToLongWord, genArg1)) | _ => (Unary{oper=oper, arg1=genArg1}, decArg1, EnvSpecNone) end and simpBinary(oper, arg1, arg2, context as {reprocess, ...}, tailDecs) = let val (genArg1, decArg1, _ (*specArg1*)) = simpSpecial(arg1, context, tailDecs) val (genArg2, decArgs, _ (*specArg2*)) = simpSpecial(arg2, context, decArg1) in case (oper, genArg1, genArg2) of (WordComparison{test, isSigned}, Constnt(v1, _), Constnt(v2, _)) => if not(isShort v1) orelse not(isShort v2) (* E.g. arbitrary precision on unreachable path. *) then (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) else let val () = reprocess := true val testResult = case (test, isSigned) of (* TestEqual can be applied to addresses. 
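   For example, a test equivalent to 0w3 < 0w5 with both operands constant folds to
   CodeTrue here; signed tests go through toFix, which sign-extends, so a comparison
   equivalent to ~1 < 0 also folds correctly.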
*) (TestEqual, _) => toShort v1 = toShort v2 | (TestLess, false) => toShort v1 < toShort v2 | (TestLessEqual, false) => toShort v1 <= toShort v2 | (TestGreater, false) => toShort v1 > toShort v2 | (TestGreaterEqual, false) => toShort v1 >= toShort v2 | (TestLess, true) => toFix v1 < toFix v2 | (TestLessEqual, true) => toFix v1 <= toFix v2 | (TestGreater, true) => toFix v1 > toFix v2 | (TestGreaterEqual, true) => toFix v1 >= toFix v2 | (TestUnordered, _) => raise InternalError "WordComparison: TestUnordered" in (if testResult then CodeTrue else CodeFalse, decArgs, EnvSpecNone) end | (PointerEq, Constnt(v1, _), Constnt(v2, _)) => ( reprocess := true; (if RunCall.pointerEq(v1, v2) then CodeTrue else CodeFalse, decArgs, EnvSpecNone) ) | (FixedPrecisionArith arithOp, Constnt(v1, _), Constnt(v2, _)) => if not(isShort v1) orelse not(isShort v2) then (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) else let val () = reprocess := true val v1S = toFix v1 and v2S = toFix v2 fun asConstnt v = Constnt(toMachineWord v, []) val raiseOverflow = Raise(Constnt(toMachineWord Overflow, [])) val raiseDiv = Raise(Constnt(toMachineWord Div, [])) (* ?? There's usually an explicit test. *) val resultCode = case arithOp of ArithAdd => (asConstnt(v1S+v2S) handle Overflow => raiseOverflow) | ArithSub => (asConstnt(v1S-v2S) handle Overflow => raiseOverflow) | ArithMult => (asConstnt(v1S*v2S) handle Overflow => raiseOverflow) | ArithQuot => (asConstnt(FixedInt.quot(v1S,v2S)) handle Overflow => raiseOverflow | Div => raiseDiv) | ArithRem => (asConstnt(FixedInt.rem(v1S,v2S)) handle Overflow => raiseOverflow | Div => raiseDiv) | ArithDiv => (asConstnt(FixedInt.div(v1S,v2S)) handle Overflow => raiseOverflow | Div => raiseDiv) | ArithMod => (asConstnt(FixedInt.mod(v1S,v2S)) handle Overflow => raiseOverflow | Div => raiseDiv) in (resultCode, decArgs, EnvSpecNone) end (* Addition and subtraction of zero. These can arise as a result of inline expansion of more general functions. 
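   For example, after inline-expanding something like
       fun shift k i = i + k
   at a call site shift 0, the residual i + 0 is reduced back to just i here.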
*) | (FixedPrecisionArith ArithAdd, arg1, Constnt(v2, _)) => if isShort v2 andalso toShort v2 = 0w0 then (arg1, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) | (FixedPrecisionArith ArithAdd, Constnt(v1, _), arg2) => if isShort v1 andalso toShort v1 = 0w0 then (arg2, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) | (FixedPrecisionArith ArithSub, arg1, Constnt(v2, _)) => if isShort v2 andalso toShort v2 = 0w0 then (arg1, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) | (WordArith arithOp, Constnt(v1, _), Constnt(v2, _)) => if not(isShort v1) orelse not(isShort v2) then (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) else let val () = reprocess := true val v1S = toShort v1 and v2S = toShort v2 fun asConstnt v = Constnt(toMachineWord v, []) val resultCode = case arithOp of ArithAdd => asConstnt(v1S+v2S) | ArithSub => asConstnt(v1S-v2S) | ArithMult => asConstnt(v1S*v2S) | ArithQuot => raise InternalError "WordArith: ArithQuot" | ArithRem => raise InternalError "WordArith: ArithRem" | ArithDiv => asConstnt(v1S div v2S) | ArithMod => asConstnt(v1S mod v2S) in (resultCode, decArgs, EnvSpecNone) end | (WordArith ArithAdd, arg1, Constnt(v2, _)) => if isShort v2 andalso toShort v2 = 0w0 then (arg1, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) | (WordArith ArithAdd, Constnt(v1, _), arg2) => if isShort v1 andalso toShort v1 = 0w0 then (arg2, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) | (WordArith ArithSub, arg1, Constnt(v2, _)) => if isShort v2 andalso toShort v2 = 0w0 then (arg1, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) | (WordLogical logOp, Constnt(v1, _), Constnt(v2, _)) => if not(isShort v1) orelse not(isShort v2) then (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) else let val () = reprocess := true val v1S = toShort v1 and v2S = toShort v2 fun asConstnt v = Constnt(toMachineWord v, []) val resultCode = case logOp of LogicalAnd => asConstnt(Word.andb(v1S,v2S)) | LogicalOr => asConstnt(Word.orb(v1S,v2S)) | LogicalXor => asConstnt(Word.xorb(v1S,v2S)) in (resultCode, decArgs, EnvSpecNone) end | (WordLogical logop, arg1, Constnt(v2, _)) => (* Return the zero if we are anding with zero otherwise the original arg *) if isShort v2 andalso toShort v2 = 0w0 then (case logop of LogicalAnd => CodeZero | _ => arg1, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) | (WordLogical logop, Constnt(v1, _), arg2) => if isShort v1 andalso toShort v1 = 0w0 then (case logop of LogicalAnd => CodeZero | _ => arg2, decArgs, EnvSpecNone) else (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) (* TODO: Constant folding of shifts. *) | _ => (Binary{oper=oper, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) end (* Arbitrary precision operations. This is a sort of mixture of a built-in and a conditional. *) and simpArbitraryCompare(TestEqual, _, _, _, _, _, _) = (* We no longer generate this for equality. General equality for arbitrary precision uses a combination of PointerEq and byte comparison. 
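   For the orderings that are generated, a mixed variable/constant case such as i < c
   (with c short) is rewritten below along the lines of this sketch (i, c and
   negFlagsConstant standing for the processed arguments and the flags constant):
       Cond(Unary{oper=IsTaggedValue, arg1=i},
            Binary{oper=WordComparison{test=TestLess, isSigned=true}, arg1=i, arg2=c},
            Binary{oper=WordComparison{test=TestEqual, isSigned=false},
                   arg1=Unary{oper=MemoryCellFlags, arg1=i}, arg2=negFlagsConstant})
   testing the sign flag of the long-form cell instead of making the full call.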
*) raise InternalError "simpArbitraryCompare: TestEqual" | simpArbitraryCompare(test, shortCond, arg1, arg2, longCall, context as {reprocess, ...}, tailDecs) = let val (genCond, decCond, _ (*specArg1*)) = simpSpecial(shortCond, context, tailDecs) val (genArg1, decArg1, _ (*specArg1*)) = simpSpecial(arg1, context, decCond) val (genArg2, decArgs, _ (*specArg2*)) = simpSpecial(arg2, context, decArg1) val posFlags = Address.F_bytes and negFlags = Word8.orb(Address.F_bytes, Address.F_negative) in (* Fold any constant/constant operations but more importantly, if we have variable/constant operations where the constant is short we can avoid using the full arbitrary precision call by just looking at the sign bit. *) case (genCond, genArg1, genArg2) of (_, Constnt(v1, _), Constnt(v2, _)) => let val a1: LargeInt.int = RunCall.unsafeCast v1 and a2: LargeInt.int = RunCall.unsafeCast v2 val testResult = case test of TestLess => a1 < a2 | TestGreater => a1 > a2 | TestLessEqual => a1 <= a2 | TestGreaterEqual => a1 >= a2 | _ => raise InternalError "simpArbitraryCompare: Unimplemented function" in (if testResult then CodeTrue else CodeFalse, decArgs, EnvSpecNone) end | (Constnt(c1, _), _, _) => (* The condition is "isShort X andalso isShort Y". This will have been reduced to a constant false or true if either (a) either argument is long or (b) both arguments are short.*) if isShort c1 andalso toShort c1 = 0w0 then (* One argument is definitely long - generate the long form. *) (simplify(longCall, context), decArgs, EnvSpecNone) else (* Both arguments are short. That should mean they're constants. *) (Binary{oper=WordComparison{test=test, isSigned=true}, arg1=genArg1, arg2=genArg2}, decArgs, EnvSpecNone) before reprocess := true | (_, genArg1, cArg2 as Constnt _) => let (* The constant must be short otherwise the test would be false. *) val isNeg = case test of TestLess => true | TestLessEqual => true | _ => false (* Translate i < c into if isShort i then toShort i < c else isNegative i *) val newCode = Cond(Unary{oper=BuiltIns.IsTaggedValue, arg1=genArg1}, Binary { oper = BuiltIns.WordComparison{test=test, isSigned=true}, arg1 = genArg1, arg2 = cArg2 }, Binary { oper = BuiltIns.WordComparison{test=TestEqual, isSigned=false}, arg1=Unary { oper = MemoryCellFlags, arg1=genArg1 }, arg2=Constnt(toMachineWord(if isNeg then negFlags else posFlags), [])} ) in (newCode, decArgs, EnvSpecNone) end | (_, cArg1 as Constnt _, genArg2) => let (* We're testing c < i so the test is if isShort i then c < toShort i else isPositive i *) val isPos = case test of TestLess => true | TestLessEqual => true | _ => false val newCode = Cond(Unary{oper=BuiltIns.IsTaggedValue, arg1=genArg2}, Binary { oper = BuiltIns.WordComparison{test=test, isSigned=true}, arg1 = cArg1, arg2 = genArg2 }, Binary { oper = BuiltIns.WordComparison{test=TestEqual, isSigned=false}, arg1=Unary { oper = MemoryCellFlags, arg1=genArg2 }, arg2=Constnt(toMachineWord(if isPos then posFlags else negFlags), [])} ) in (newCode, decArgs, EnvSpecNone) end | _ => (Arbitrary{oper=ArbCompare test, shortCond=genCond, arg1=genArg1, arg2=genArg2, longCall=simplify(longCall, context)}, decArgs, EnvSpecNone) end and simpArbitraryArith(arith, shortCond, arg1, arg2, longCall, context, tailDecs) = let (* arg1 and arg2 are the arguments. shortCond is the condition that must be satisfied in order to use the short precision operation i.e. each argument must be short. 
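   For example, if both operands are constants the operation folds outright (2 + 3
   becomes the arbitrary-precision constant 5), and if shortCond has already simplified
   to constant false the whole node is replaced by the processed longCall.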
*) val (genCond, decCond, _ (*specArg1*)) = simpSpecial(shortCond, context, tailDecs) val (genArg1, decArg1, _ (*specArg1*)) = simpSpecial(arg1, context, decCond) val (genArg2, decArgs, _ (*specArg2*)) = simpSpecial(arg2, context, decArg1) in case (genArg1, genArg2, genCond) of (Constnt(v1, _), Constnt(v2, _), _) => let val a1: LargeInt.int = RunCall.unsafeCast v1 and a2: LargeInt.int = RunCall.unsafeCast v2 (*val _ = print ("Fold arbitrary precision: " ^ PolyML.makestring(arith, a1, a2) ^ "\n")*) in case arith of ArithAdd => (Constnt(toMachineWord(a1+a2), []), decArgs, EnvSpecNone) | ArithSub => (Constnt(toMachineWord(a1-a2), []), decArgs, EnvSpecNone) | ArithMult => (Constnt(toMachineWord(a1*a2), []), decArgs, EnvSpecNone) | _ => raise InternalError "simpArbitraryArith: Unimplemented function" end | (_, _, Constnt(c1, _)) => if isShort c1 andalso toShort c1 = 0w0 then (* One argument is definitely long - generate the long form. *) (simplify(longCall, context), decArgs, EnvSpecNone) else (Arbitrary{oper=ArbArith arith, shortCond=genCond, arg1=genArg1, arg2=genArg2, longCall=simplify(longCall, context)}, decArgs, EnvSpecNone) | _ => (Arbitrary{oper=ArbArith arith, shortCond=genCond, arg1=genArg1, arg2=genArg2, longCall=simplify(longCall, context)}, decArgs, EnvSpecNone) end and simpAllocateWordMemory(numWords, flags, initial, context, tailDecs) = let val (genArg1, decArg1, _ (*specArg1*)) = simpSpecial(numWords, context, tailDecs) val (genArg2, decArg2, _ (*specArg2*)) = simpSpecial(flags, context, decArg1) val (genArg3, decArg3, _ (*specArg3*)) = simpSpecial(initial, context, decArg2) in (AllocateWordMemory{numWords=genArg1, flags=genArg2, initial=genArg3}, decArg3, EnvSpecNone) end (* Loads, stores and block operations use address values. The index value is initially an arbitrary code tree but we can recognise common cases of constant index values or where a constant has been added to the index. TODO: If these are C memory moves we can also look at the base address. The base address for C memory operations is a LargeWord.word value i.e. the address is contained in a box. The base addresses for ML memory moves is an ML address i.e. unboxed. *) and simpAddress({base, index=NONE, offset}, _, context) = let val (genBase, decBase, _ (*specBase*)) = simpSpecial(base, context, RevList[]) in ({base=genBase, index=NONE, offset=offset}, decBase) end | simpAddress({base, index=SOME index, offset}, multiplier, context) = let val (genBase, RevList decBase, _) = simpSpecial(base, context, RevList[]) val (genIndex, RevList decIndex, _ (* specIndex *)) = simpSpecial(index, context, RevList[]) val (newIndex, newOffset) = case genIndex of Constnt(indexOffset, _) => (* Convert small, positive offsets but leave large values as indexes. We could have silly index values here which will never be executed because of a range check but should still compile. *) if isShort indexOffset andalso toShort indexOffset < 0w1000 then (NONE, offset + toShort indexOffset * multiplier) else (SOME genIndex, offset) | _ => (SOME genIndex, offset) in ({base=genBase, index=newIndex, offset=newOffset}, RevList(decIndex @ decBase)) end (* (* A built-in function. We can call certain built-ins immediately if the arguments are constants. *) and simpBuiltIn(rtsCallNo, argList, context as { reprocess, ...}) = let val copiedArgs = map (fn arg => simpSpecial(arg, context)) argList open RuntimeCalls (* When checking for a constant we need to check that there are no bindings. They could have side-effects. 
*) fun isAConstant(Constnt _, [], _) = true | isAConstant _ = false in (* If the function is an RTS call that is safe to evaluate immediately and all the arguments are constants evaluate it now. *) if earlyRtsCall rtsCallNo andalso List.all isAConstant copiedArgs then let val () = reprocess := true exception Interrupt = Thread.Thread.Interrupt (* Turn the arguments into a vector. *) val argVector = case makeConstVal(mkTuple(List.map specialToGeneral copiedArgs)) of Constnt(w, _) => w | _ => raise InternalError "makeConstVal: Not constant" (* Call the function. If it raises an exception (e.g. divide by zero) generate code to raise the exception at run-time. We don't do that for Interrupt which we assume only arises by user interaction and not as a result of executing the code so we reraise that exception immediately. *) val ioOp : int -> machineWord = RunCall.run_call1 RuntimeCalls.POLY_SYS_io_operation (* We need callcode_tupled here because we pass the arguments as a tuple but the RTS functions we're calling expect arguments in registers or on the stack. *) val call: (address * machineWord) -> machineWord = RunCall.run_call1 RuntimeCalls.POLY_SYS_callcode_tupled val code = Constnt (call(toAddress(ioOp rtsCallNo), argVector), []) handle exn as Interrupt => raise exn (* Must not handle this *) | exn => Raise (Constnt(toMachineWord exn, [])) in (code, [], EnvSpecNone) end (* We can optimise certain built-ins in combination with others. If we have POLY_SYS_unsigned_to_longword combined with POLY_SYS_longword_to_tagged we can eliminate both. This can occur in cases such as Word.fromLargeWord o Word8.toLargeWord. If we have POLY_SYS_cmem_load_X functions where the address is formed by adding a constant to an address we can move the addend into the load instruction. *) (* TODO: Could we also have POLY_SYS_signed_to_longword here? *) else if rtsCallNo = POLY_SYS_longword_to_tagged andalso (case copiedArgs of [(_, _, EnvSpecBuiltIn(r, _))] => r = POLY_SYS_unsigned_to_longword | _ => false) then let val arg = (* Get the argument of the argument. *) case copiedArgs of [(_, _, EnvSpecBuiltIn(_, [arg]))] => arg | _ => raise Bind in (arg, [], EnvSpecNone) end else if (rtsCallNo = POLY_SYS_cmem_load_8 orelse rtsCallNo = POLY_SYS_cmem_load_16 orelse rtsCallNo = POLY_SYS_cmem_load_32 orelse rtsCallNo = POLY_SYS_cmem_load_64 orelse rtsCallNo = POLY_SYS_cmem_store_8 orelse rtsCallNo = POLY_SYS_cmem_store_16 orelse rtsCallNo = POLY_SYS_cmem_store_32 orelse rtsCallNo = POLY_SYS_cmem_store_64) andalso (* Check if the first argument is an addition. The second should be a constant. If the addend is a constant it will be a large integer i.e. the address of a byte segment. *) let (* Check that we have a valid value to add to a large word. The cmem_load/store values sign extend their arguments so we use toLargeWordX here. *) fun isAcceptableOffset c = if isShort c (* Shouldn't occur. *) then false else let val l: LargeWord.word = RunCall.unsafeCast c in Word.toLargeWordX(Word.fromLargeWord l) = l end in case copiedArgs of (_, _, EnvSpecBuiltIn(r, args)) :: (Constnt _, _, _) :: _ => r = POLY_SYS_plus_longword andalso (case args of (* If they were both constants we'd have folded them. *) [Constnt(c, _), _] => isAcceptableOffset c | [_, Constnt(c, _)] => isAcceptableOffset c | _ => false) | _ => false end then let (* We have a load or store with an added constant. 
*) val (base, offset) = case copiedArgs of (_, _, EnvSpecBuiltIn(_, [Constnt(offset, _), base])) :: (Constnt(existing, _), _, _) :: _ => (base, Word.fromLargeWord(RunCall.unsafeCast offset) + toShort existing) | (_, _, EnvSpecBuiltIn(_, [base, Constnt(offset, _)])) :: (Constnt(existing, _), _, _) :: _ => (base, Word.fromLargeWord(RunCall.unsafeCast offset) + toShort existing) | _ => raise Bind val newDecs = List.map(fn h => makeNewDecl(h, context)) copiedArgs val genArgs = List.map(fn ((g, _), _) => envGeneralToCodetree g) newDecs val preDecs = List.foldr (op @) [] (List.map #2 newDecs) val gen = BuiltIn(rtsCallNo, base :: Constnt(toMachineWord offset, []) :: List.drop(genArgs, 2)) in (gen, preDecs, EnvSpecNone) end else let (* Create bindings for the arguments. This ensures that any side-effects in the evaluation of the arguments are performed in the correct order even if the application of the built-in itself is applicative. The new arguments are either loads or constants which are applicative. *) val newDecs = List.map(fn h => makeNewDecl(h, context)) copiedArgs val genArgs = List.map(fn ((g, _), _) => envGeneralToCodetree g) newDecs val preDecs = List.foldr (op @) [] (List.map #2 newDecs) val gen = BuiltIn(rtsCallNo, genArgs) val spec = if reorderable gen then EnvSpecBuiltIn(rtsCallNo, genArgs) else EnvSpecNone in (gen, preDecs, spec) end end *) and simpIfThenElse(condTest, condThen, condElse, context, tailDecs) = (* If-then-else. The main simplification occurs when we have constants in the test or in both arms. *) let val word0 = toMachineWord 0 val word1 = toMachineWord 1 val False = word0 val True = word1 in case simpSpecial(condTest, context, tailDecs) of (* If the test is a constant we can return the appropriate arm and ignore the other. *) (Constnt(testResult, _), bindings, _) => let val arm = if wordEq (testResult, False) (* false - return else-part *) then condElse (* if false then x else y == y *) (* if true then x else y == x *) else condThen in simpSpecial(arm, context, bindings) end | (testGen, testbindings as RevList testBList, testSpec) => let fun mkNot (Unary{oper=BuiltIns.NotBoolean, arg1}) = arg1 | mkNot arg = Unary{oper=BuiltIns.NotBoolean, arg1=arg} (* If the test involves a variable that was created with a NOT it's better to move it in here. *) val testCond = case testSpec of EnvSpecUnary(BuiltIns.NotBoolean, arg1) => mkNot arg1 | _ => testGen in case (simpSpecial(condThen, context, RevList[]), simpSpecial(condElse, context, RevList[])) of ((thenConst as Constnt(thenVal, _), RevList [], _), (elseConst as Constnt(elseVal, _), RevList [], _)) => (* Both arms return constants. This can arise where we have andalso/orelse and the second "argument" has been reduced to a constant. *) if wordEq (thenVal, elseVal) then (* If the test has a side-effect we have to execute it; otherwise we can remove it. If we're in a nested andalso/orelse that may mean we can simplify the next level out.
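   For example (a sketch), x andalso false appears here as
       if x then false else false
   and since both arms are the same constant it reduces to false, keeping x as a
   NullBinding only if it may have side-effects; that in turn can collapse an
   enclosing andalso/orelse level on the next pass.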
*) (thenConst (* or elseConst *), if sideEffectFree testCond then testbindings else RevList(NullBinding testCond :: testBList), EnvSpecNone) (* if x then true else false == x *) else if wordEq (thenVal, True) andalso wordEq (elseVal, False) then (testCond, testbindings, EnvSpecNone) (* if x then false else true == not x *) else if wordEq (thenVal, False) andalso wordEq (elseVal, True) then (mkNot testCond, testbindings, EnvSpecNone) else (* can't optimise *) (Cond (testCond, thenConst, elseConst), testbindings, EnvSpecNone) (* Rewrite "if x then raise y else z" into "(if x then raise y else (); z)" The advantage is that any tuples in z are lifted outside the "if". *) | (thenPart as (Raise _, _:revlist, _), (elsePart, RevList elseBindings, elseSpec)) => (* then-part raises an exception *) (elsePart, RevList(elseBindings @ NullBinding(Cond (testCond, specialToGeneral thenPart, CodeZero)) :: testBList), elseSpec) | ((thenPart, RevList thenBindings, thenSpec), elsePart as (Raise _, _, _)) => (* else part raises an exception *) (thenPart, RevList(thenBindings @ NullBinding(Cond (testCond, CodeZero, specialToGeneral elsePart)) :: testBList), thenSpec) | (thenPart, elsePart) => (Cond (testCond, specialToGeneral thenPart, specialToGeneral elsePart), testbindings, EnvSpecNone) end end (* Tuple construction. Tuples are also used for datatypes and structures (i.e. modules) *) and simpTuple(entries, isVariant, context, tailDecs) = (* The main reason for optimising record constructions is that they appear as tuples in ML. We try to ensure that loads from locally created tuples do not involve indirecting from the tuple but can get the value which was put into the tuple directly. If that is successful we may find that the tuple is never used directly so the use-count mechanism will ensure it is never created. *) let val tupleSize = List.length entries (* The record construction is treated as a block of local declarations so that any expressions which might have side-effects are done exactly once. *) (* We thread the bindings through here to avoid having to append the result. *) fun processFields([], bindings) = ([], bindings) | processFields(field::fields, bindings) = let val (thisField, newBindings) = makeNewDecl(simpSpecial(field, context, bindings), context) val (otherFields, resBindings) = processFields(fields, newBindings) in (thisField::otherFields, resBindings) end val (fieldEntries, allBindings) = processFields(entries, tailDecs) (* Make sure we include any inline code in the result. If this tuple is being "exported" we will lose the "special" part. *) fun envResToCodetree(EnvGenLoad(ext), _) = Extract ext | envResToCodetree(EnvGenConst(w, p), s) = Constnt(w, setInline s p) val generalFields = List.map envResToCodetree fieldEntries val genRec = if List.all isConstnt generalFields then makeConstVal(Tuple{ fields = generalFields, isVariant = isVariant }) else Tuple{ fields = generalFields, isVariant = isVariant } (* Get the field from the tuple if possible. If it's a variant, though, we may try to get an invalid field. See Tests/Succeed/Test167. 
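   For example, with
       datatype t = A | B of int
   matching code may contain a selection of B's argument on a path where the value is
   actually A, so for a variant an out-of-range selection yields a dummy constant
   rather than an internal error.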
*) fun getField addr = if addr < tupleSize then List.nth(fieldEntries, addr) else if isVariant then (EnvGenConst(toMachineWord 0, []), EnvSpecNone) else raise InternalError "getField - invalid index" val specRec = EnvSpecTuple(tupleSize, getField) in (genRec, allBindings, specRec) end and simpFieldSelect(base, offset, indKind, context, tailDecs) = let val (genSource, decSource, specSource) = simpSpecial(base, context, tailDecs) in (* Try to do the selection now if possible. *) case specSource of EnvSpecTuple(_, recEnv) => let (* The "special" entry we've found is a tuple. That means that we are taking a field from a tuple we made earlier and so we should be able to get the original code we used when we made the tuple. That might mean the tuple is never used and we can optimise away the construction of it completely. *) val (newGen, newSpec) = recEnv offset in (envGeneralToCodetree newGen, decSource, newSpec) end | _ => (* No special case possible. If the tuple is a constant mkInd/mkVarField will do the selection immediately. *) let val genSelect = case indKind of IndTuple => mkInd(offset, genSource) | IndVariant => mkVarField(offset, genSource) | IndContainer => mkIndContainer(offset, genSource) in (genSelect, decSource, EnvSpecNone) end end (* Process a SetContainer. Unlike the other simpXXX functions this is called after the arguments have been processed. We try to push the SetContainer to the leaves of the expression. This is particularly important with tail-recursive functions that return tuples. Without this the function will lose tail-recursion since each recursion will be followed by code to copy the result back to the previous container. *) and simpPostSetContainer(container, Tuple{fields, ...}, RevList tupleDecs, filter) = let (* Apply the filter now. *) fun select(n, hd::tl) = if n >= BoolVector.length filter then [] else if BoolVector.sub(filter, n) then hd :: select(n+1, tl) else select(n+1, tl) | select(_, []) = [] val selected = select(0, fields) (* Frequently we will have produced an indirection from the same base. These will all be bindings so we have to reverse the process. *) fun findOriginal a = List.find(fn Declar{addr, ...} => addr = a | _ => false) tupleDecs fun checkFields(last, Extract(LoadLocal a) :: tl) = ( case findOriginal a of SOME(Declar{value=Indirect{base=Extract ext, indKind=IndContainer, offset, ...}, ...}) => ( case last of NONE => checkFields(SOME(ext, [offset]), tl) | SOME(lastExt, offsets) => (* It has to be the same base and with increasing offsets (no reordering). *) if lastExt = ext andalso offset > hd offsets then checkFields(SOME(ext, offset :: offsets), tl) else NONE ) | _ => NONE ) | checkFields(_, _ :: _) = NONE | checkFields(last, []) = last fun fieldsToFilter fields = let val maxDest = List.foldl Int.max ~1 fields val filterArray = BoolArray.array(maxDest+1, false) val _ = List.app(fn n => BoolArray.update(filterArray, n, true)) fields in BoolArray.vector filterArray end in case checkFields(NONE, selected) of SOME (ext, fields) => (* It may be a container. *) let val filter = fieldsToFilter fields in case ext of LoadLocal localAddr => let (* Is this a container? If it is and we're copying all of it we can replace the inner container with a binding to the outer. We have to be careful because it is possible that we may create and set the inner container, then have some bindings that do some side-effects with the inner container before then copying it to the outer container. 
For simplicity and to maintain the condition that the container is set in the tails we only merge the containers if it's at the end (after any "filtering"). *) val allSet = BoolVector.foldl (fn (a, t) => a andalso t) true filter fun findContainer [] = NONE | findContainer (Declar{value, ...} :: tl) = if sideEffectFree value then findContainer tl else NONE | findContainer (Container{addr, size, setter, ...} :: tl) = if localAddr = addr andalso size = BoolVector.length filter andalso allSet then SOME (setter, tl) else NONE | findContainer _ = NONE in case findContainer tupleDecs of SOME (setter, decs) => (* Put in a binding for the inner container address so the setter will set the outer container. For this to work all loads from the stack must use native word length. *) mkEnv(List.rev(Declar{addr=localAddr, value=container, use=[]} :: decs), setter) | NONE => mkEnv(List.rev tupleDecs, SetContainer{container=container, tuple = mkTuple selected, filter=BoolVector.tabulate(List.length selected, fn _ => true)}) end | _ => mkEnv(List.rev tupleDecs, SetContainer{container=container, tuple = mkTuple selected, filter=BoolVector.tabulate(List.length selected, fn _ => true)}) end | NONE => mkEnv(List.rev tupleDecs, SetContainer{container=container, tuple = mkTuple selected, filter=BoolVector.tabulate(List.length selected, fn _ => true)}) end | simpPostSetContainer(container, Cond(ifpt, thenpt, elsept), RevList tupleDecs, filter) = mkEnv(List.rev tupleDecs, Cond(ifpt, simpPostSetContainer(container, thenpt, RevList [], filter), simpPostSetContainer(container, elsept, RevList [], filter))) | simpPostSetContainer(container, Newenv(envDecs, envExp), RevList tupleDecs, filter) = simpPostSetContainer(container, envExp, RevList(List.rev envDecs @ tupleDecs), filter) | simpPostSetContainer(container, BeginLoop{loop, arguments}, RevList tupleDecs, filter) = mkEnv(List.rev tupleDecs, BeginLoop{loop = simpPostSetContainer(container, loop, RevList [], filter), arguments=arguments}) | simpPostSetContainer(_, loop as Loop _, RevList tupleDecs, _) = (* If we are inside a BeginLoop we only set the container on leaves that exit the loop. Loop entries will go back to the BeginLoop so we don't add SetContainer nodes. *) mkEnv(List.rev tupleDecs, loop) | simpPostSetContainer(container, Handle{exp, handler, exPacketAddr}, RevList tupleDecs, filter) = mkEnv(List.rev tupleDecs, Handle{ exp = simpPostSetContainer(container, exp, RevList [], filter), handler = simpPostSetContainer(container, handler, RevList [], filter), exPacketAddr = exPacketAddr}) | simpPostSetContainer(container, tupleGen, RevList tupleDecs, filter) = mkEnv(List.rev tupleDecs, mkSetContainer(container, tupleGen, filter)) fun simplifier{code, numLocals, maxInlineSize} = let val localAddressAllocator = ref 0 val addrTab = Array.array(numLocals, NONE) fun lookupAddr (LoadLocal addr) = valOf(Array.sub(addrTab, addr)) | lookupAddr (env as LoadArgument _) = (EnvGenLoad env, EnvSpecNone) | lookupAddr (env as LoadRecursive) = (EnvGenLoad env, EnvSpecNone) | lookupAddr (LoadClosure _) = raise InternalError "top level reached in simplifier" and enterAddr (addr, tab) = Array.update (addrTab, addr, SOME tab) fun mkAddr () = ! localAddressAllocator before localAddressAllocator := ! localAddressAllocator + 1 val reprocess = ref false val (gen, RevList bindings, spec) = simpSpecial(code, {lookupAddr = lookupAddr, enterAddr = enterAddr, nextAddress = mkAddr, reprocess = reprocess, maxInlineSize = maxInlineSize}, RevList[]) in ((gen, List.rev bindings, spec), ! 
localAddressAllocator, !reprocess) end fun specialToGeneral(g, b as _ :: _, s) = mkEnv(b, specialToGeneral(g, [], s)) | specialToGeneral(Constnt(w, p), [], s) = Constnt(w, setInline s p) | specialToGeneral(g, [], _) = g structure Sharing = struct type codetree = codetree and codeBinding = codeBinding and envSpecial = envSpecial end end;