{
    Tail recursion optimization

    Copyright (c) 2006 by Florian Klaempfl

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

 ****************************************************************************
}
unit opttail;

{$i fpcdefs.inc}

  interface

    uses
      symdef,node;

    procedure do_opttail(var n : tnode;p : tprocdef);

  implementation

    uses
      globtype,
      symconst,symsym,
      defcmp,defutil,
      nutils,nbas,nflw,ncal,nld,ncnv,nmem,
      pass_1,
      paramgr;

    { Rewrites recursive calls of p that are found in tail position of the
      tree n into "assign the new parameter values, then goto a label".

      n : node tree to process; when at least one call was rewritten, n is
          replaced by a new statement block consisting of the label
          followed by the old tree
      p : procedure definition the calls must refer to

      Nothing is changed if p has a vs_out parameter or a parameter of a
      managed type. }
    procedure do_opttail(var n : tnode;p : tprocdef);

      var
        { jump target of all rewritten calls }
        looplabel : tlabelnode;

      { Inspects the node(s) of n that are reached last: the last statement
        of a statement chain, both sub trees (right and t1) of an if node,
        the content of a block and n itself when it is a call or an
        assignment.

        A call of p (without method pointer and without copy-back
        parameters), either stand-alone or assigned to the function result,
        is replaced by parameter assignments and a goto to looplabel.

        Returns true if at least one call was replaced. }
      function find_and_replace_tailcalls(var n : tnode) : boolean;

        var
          { call node recorded by the last successful
            is_optimizable_recursivecall }
          matchedcall : tcallnode;

        { Returns true if at least one parameter node of call has
          fparacopyback assigned. }
        function has_copyback_paras(call: tcallnode): boolean;
          var
            para: tcallparanode;
          begin
            { skip all parameters without copy-back node; if the walk stops
              before the end of the list, a copy-back parameter was hit }
            para:=tcallparanode(call.left);
            while assigned(para) and
                  not(assigned(para.fparacopyback)) do
              para:=tcallparanode(para.right);
            result:=assigned(para);
          end;


        { Returns true if n is a call of p that has no method pointer and no
          copy-back parameters, possibly wrapped into tc_equal type
          conversions. On success the call node is stored in matchedcall. }
        function is_optimizable_recursivecall(n : tnode) : boolean;
          begin
            case n.nodetype of
              calln:
                begin
                  result:=
                    (tcallnode(n).procdefinition=p) and
                    not(assigned(tcallnode(n).methodpointer)) and
                    not(has_copyback_paras(tcallnode(n)));
                  if result then
                    matchedcall:=tcallnode(n);
                end;
              typeconvn:
                { obsolete type cast? }
                result:=
                  (ttypeconvnode(n).convtype=tc_equal) and
                  is_optimizable_recursivecall(ttypeconvnode(n).left);
              else
                result:=false;
            end;
          end;


        { Returns true if n loads p.funcretsym, possibly wrapped into
          tc_equal type conversions. }
        function is_resultassignment(n : tnode) : boolean;
          begin
            case n.nodetype of
              loadn:
                result:=tloadnode(n).symtableentry=p.funcretsym;
              typeconvn:
                result:=
                  (ttypeconvnode(n).convtype=tc_equal) and
                  is_resultassignment(ttypeconvnode(n).left);
              else
                result:=false;
            end;
          end;

        var
          calcblock,
          copyblock,
          laststmt,
          replacednode : tnode;
          jumpstmts,
          calcstmts,
          copystmts : tstatementnode;
          para : tcallparanode;
          temp : ttempcreatenode;
          paraload : tloadnode;
          byaddr,
          istailcall,
          foundinright,
          foundint1 : boolean;
        begin
          { nothing replaced yet }
          result:=false;
          if not(assigned(n)) then
            exit;
          matchedcall:=nil;
          case n.nodetype of
            statementn:
              begin
                { walk to the end of the statement chain and continue with
                  the node of the last statement }
                laststmt:=n;
                while assigned(tstatementnode(laststmt).right) do
                  laststmt:=tstatementnode(laststmt).right;
                result:=find_and_replace_tailcalls(tstatementnode(laststmt).left);
              end;
            ifn:
              begin
                { both sub trees must be processed unconditionally, so each
                  call gets its own statement instead of being part of one
                  short-circuit evaluated expression }
                foundinright:=find_and_replace_tailcalls(tifnode(n).right);
                foundint1:=find_and_replace_tailcalls(tifnode(n).t1);
                result:=foundint1 or foundinright;
              end;
            blockn:
              result:=find_and_replace_tailcalls(tblocknode(n).left);
            calln,
            assignn:
              begin
                if n.nodetype=calln then
                  istailcall:=is_optimizable_recursivecall(n)
                else
                  istailcall:=
                    is_resultassignment(tbinarynode(n).left) and
                    is_optimizable_recursivecall(tbinarynode(n).right);

                if istailcall then
                  begin
                    { found one! }
                    {
                    writeln('tail recursion optimization for ',p.mangledname);
                    printnode(output,n);
                    }
                    { create assignments for all parameters }

                    { one parameter could be used to calculate another one,
                      so all new values are first stored in temps
                      (calcblock) and only afterwards copied into the
                      parameters (copyblock) }
                    calcblock:=internalstatements(calcstmts);
                    copyblock:=internalstatements(copystmts);
                    para:=tcallparanode(matchedcall.left);
                    while assigned(para) do
                      begin
                        { the init code of the parameter is moved in front
                          of the calculation of its value }
                        if assigned(para.fparainit) then
                          begin
                            addstatement(calcstmts,para.fparainit);
                            para.fparainit:=nil;
                          end;

                        { decide whether the address of the actual parameter
                          is transferred instead of its value }
                        case para.parasym.varspez of
                          vs_var,
                          vs_constref:
                            byaddr:=true;
                          vs_const:
                            byaddr:=paramanager.push_addr_param(para.parasym.varspez,para.parasym.vardef,p.proccalloption);
                          vs_value:
                            byaddr:=is_open_array(para.parasym.vardef);
                          else
                            byaddr:=false;
                        end;

                        if byaddr then
                          begin
                            { pointer sized temp receiving the address of
                              the actual parameter }
                            temp:=ctempcreatenode.create(voidpointertype,voidpointertype.size,tt_persistent,true);
                            addstatement(calcstmts,temp);
                            addstatement(calcstmts,
                              cassignmentnode.create(
                                ctemprefnode.create(temp),
                                caddrnode.create_internal(para.left)
                              ));
                          end
                        else
                          begin
                            { temp of the type of the actual parameter
                              receiving its value }
                            temp:=ctempcreatenode.create(para.left.resultdef,para.left.resultdef.size,tt_persistent,true);
                            addstatement(calcstmts,temp);
                            addstatement(calcstmts,
                              cassignmentnode.create_internal(
                                ctemprefnode.create(temp),
                                para.left
                              ));
                          end;

                        { "cast" away const varspezs }
                        paraload:=cloadnode.create(para.parasym,para.parasym.owner);
                        include(paraload.loadnodeflags,loadnf_isinternal_ignoreconst);

                        { load the address of the symbol instead of symbol }
                        if byaddr then
                          include(paraload.loadnodeflags,loadnf_load_addr);

                        { copy the temp into the parameter and release the
                          temp afterwards }
                        addstatement(copystmts,
                          cassignmentnode.create_internal(
                            paraload,
                            ctemprefnode.create(temp)
                          ));
                        addstatement(copystmts,ctempdeletenode.create_normal_temp(temp));

                        { the value node is reused in the assignment above,
                          so detach it from the parameter node which gets
                          freed together with the old tree }
                        para.left:=nil;
                        para:=tcallparanode(para.right);
                      end;

                    { the new block takes the place of the call or the
                      assignment }
                    replacednode:=n;
                    n:=internalstatements(jumpstmts);

                    if assigned(matchedcall.callinitblock) then
                      begin
                        addstatement(jumpstmts,matchedcall.callinitblock);
                        matchedcall.callinitblock:=nil;
                      end;

                    addstatement(jumpstmts,calcblock);
                    addstatement(jumpstmts,copyblock);

                    { create goto }
                    addstatement(jumpstmts,cgotonode.create(looplabel.labsym));

                    if assigned(matchedcall.callcleanupblock) then
                      begin
                        { callcleanupblock should contain only temp. node clean up }
                        checktreenodetypes(matchedcall.callcleanupblock,
                          [tempdeleten,blockn,statementn,temprefn,nothingn]);
                        addstatement(jumpstmts,matchedcall.callcleanupblock);
                        matchedcall.callcleanupblock:=nil;
                      end;

                    { everything worth keeping was detached above }
                    replacednode.free;

                    do_firstpass(n);
                    result:=true;
                  end;
              end;
            else
              { all other node types are left untouched }
              ;
          end;
        end;

      var
        wrapperstmts : tstatementnode;
        oldtree : tnode;
        idx : longint;
        formal : tparavarsym;
        looplabelsym : tlabelsym;
      begin
        { check if the parameters actually would support tail recursion elimination }
        for idx:=0 to p.paras.count-1 do
          begin
            formal:=tparavarsym(p.paras[idx]);
            { parameters requiring tables are too complicated to handle
              and slow down things anyways so a tail recursion call
              makes no sense }
            if (formal.varspez=vs_out) or
               is_managed_type(formal.vardef) then
              exit;
          end;

        looplabelsym:=clabelsym.create('$opttail');
        looplabel:=clabelnode.create(cnothingnode.create,looplabelsym);

        { without any replaced call, the label node is not needed }
        if not(find_and_replace_tailcalls(n)) then
          begin
            looplabel.free;
            exit;
          end;

        { put the label in front of the old tree }
        oldtree:=n;
        n:=internalstatements(wrapperstmts);
        addstatement(wrapperstmts,looplabel);
        addstatement(wrapperstmts,oldtree);
      end;

end.