package org.renjin.pipeliner;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.renjin.pipeliner.fusion.FusedNode;
import org.renjin.pipeliner.fusion.LoopKernelCache;
import org.renjin.pipeliner.fusion.LoopKernels;
import org.renjin.pipeliner.node.CallNode;
import org.renjin.pipeliner.node.DataNode;
import org.renjin.pipeliner.node.DeferredNode;
import org.renjin.pipeliner.node.FunctionNode;
import org.renjin.pipeliner.node.OutputNode;
import org.renjin.pipeliner.optimize.Optimizers;
import org.renjin.primitives.ni.DeferredNativeCall;
import org.renjin.primitives.ni.NativeOutputVector;
import org.renjin.primitives.vector.DeferredComputation;
import org.renjin.repackaged.guava.base.Preconditions;
import org.renjin.repackaged.guava.collect.HashMultimap;
import org.renjin.repackaged.guava.collect.Lists;
import org.renjin.repackaged.guava.collect.Maps;
import org.renjin.repackaged.guava.collect.Multimap;
import org.renjin.repackaged.guava.collect.Sets;
import org.renjin.sexp.Vector;

/* loaded from: input_file:org/renjin/pipeliner/DeferredGraph.class */
/**
 * A graph of deferred computations, built by walking deferred {@link Vector}s and
 * {@link DeferredNativeCall}s back through their operands.
 *
 * <p>Each vector (compared by identity) and each native call (compared by identity) is
 * represented by a single node. Deferred computations that share a computation name and have
 * equivalent operands are collapsed onto one {@link FunctionNode}, so a repeated sub-expression
 * appears only once in the graph.
 *
 * <p>This class performs no synchronization.
 */
public class DeferredGraph {

    /** Root nodes in insertion order; {@link #getRootResult(int)} indexes into this list. */
    private final List<DeferredNode> rootNodes = new ArrayList<>();

    /** Every node that has been added to the graph. */
    private final List<DeferredNode> nodes = Lists.newArrayList();

    /** Node already created for each vector, keyed by identity. */
    private final IdentityHashMap<Vector, DeferredNode> vectorMap = Maps.newIdentityHashMap();

    /** Node already created for each native call, keyed by identity. */
    private final IdentityHashMap<DeferredNativeCall, CallNode> callMap = Maps.newIdentityHashMap();

    /** Function nodes grouped by computation name, used to share equivalent computations. */
    private final Multimap<String, FunctionNode> computationIndex = HashMultimap.create();

    /**
     * Creates a graph rooted at a single deferred native call.
     *
     * @param deferredNativeCall the call to add, together with its operands
     */
    public DeferredGraph(DeferredNativeCall deferredNativeCall) {
        addRoot(deferredNativeCall);
    }

    /**
     * Creates a graph rooted at a single vector.
     *
     * @param vector the vector to add, together with its operands
     */
    public DeferredGraph(Vector vector) {
        addRoot(vector);
    }

    /** Creates an empty graph; add roots with {@link #addRoot(Vector)}. */
    public DeferredGraph() {
    }

    /**
     * Runs the graph optimizers, then fuses every supported node into a loop kernel.
     *
     * @param loopKernelCache cache handed to each fused node when its compilation is started
     */
    public void optimize(LoopKernelCache loopKernelCache) {
        new Optimizers().optimize(this);
        fuse(loopKernelCache);
    }

    /**
     * Walks the graph from every root, operands before their users, and replaces each node
     * accepted by {@link LoopKernels} with a {@link FusedNode} whose compilation is started
     * immediately.
     *
     * @param loopKernelCache cache handed to each fused node when its compilation is started
     */
    public void fuse(LoopKernelCache loopKernelCache) {
        Set<DeferredNode> visited = Sets.newIdentityHashSet();
        // Snapshot: replaceNode() may swap entries of rootNodes while we iterate.
        for (DeferredNode root : new ArrayList<DeferredNode>(rootNodes)) {
            fuse(loopKernelCache, visited, root);
        }
    }

    /** Depth-first fusion of {@code node}'s operands, then of {@code node} itself. */
    private void fuse(LoopKernelCache loopKernelCache, Set<DeferredNode> visited, DeferredNode node) {
        if (!visited.add(node)) {
            // Already handled (and replaced, if it was fusable). Trying again would create and
            // compile a second FusedNode for the same node.
            return;
        }
        // Snapshot: fusing an operand may rewrite this node's operand list via replaceNode().
        for (DeferredNode operand : new ArrayList<DeferredNode>(node.getOperands())) {
            fuse(loopKernelCache, visited, operand);
        }
        FusedNode fused = tryFuse(node);
        if (fused != null) {
            fused.startCompilation(loopKernelCache);
            replaceNode(node, fused);
        }
    }

    /**
     * Returns a fused replacement for {@code node}, or {@code null} if it cannot be fused.
     */
    private FusedNode tryFuse(DeferredNode node) {
        // FusedNode can only wrap a FunctionNode, so check the type before casting.
        if (node instanceof FunctionNode && LoopKernels.INSTANCE.supports(node)) {
            return new FusedNode((FunctionNode) node);
        }
        return null;
    }

    /**
     * Adds {@code vector} (and, transitively, its operands) to the graph and records its node as
     * the next root. Adding the same vector twice records the same node twice.
     *
     * @param vector the vector whose result can later be read with {@link #getRootResult(int)}
     * @throws UnsupportedOperationException if a deferred vector in the graph is neither a
     *     {@link NativeOutputVector} nor a {@link DeferredComputation}
     */
    public void addRoot(Vector vector) {
        rootNodes.add(addNode(vector));
    }

    /** Returns the node for {@code vector}, creating it (and its operands' nodes) if needed. */
    private DeferredNode addNode(Vector vector) {
        DeferredNode existing = vectorMap.get(vector);
        if (existing != null) {
            return existing;
        }
        if (!vector.isDeferred()) {
            return addDataNode(vector);
        }
        if (vector instanceof NativeOutputVector) {
            return addOutputNode(vector);
        }
        if (vector instanceof DeferredComputation) {
            return addComputeNode((DeferredComputation) vector);
        }
        throw new UnsupportedOperationException("deferred: " + vector.getClass().getName());
    }

    /** Wraps an already-computed vector in a {@link DataNode}. */
    private DataNode addDataNode(Vector vector) {
        DataNode dataNode = new DataNode(vector);
        vectorMap.put(vector, dataNode);
        nodes.add(dataNode);
        return dataNode;
    }

    /**
     * Adds a node for a deferred computation. If a function node with the same computation name
     * and equivalent operands already exists, that node is reused instead.
     */
    private DeferredNode addComputeNode(DeferredComputation computation) {
        Vector[] operandVectors = computation.getOperands();
        DeferredNode[] operands = new DeferredNode[operandVectors.length];
        for (int i = 0; i < operandVectors.length; i++) {
            operands[i] = addNode(operandVectors[i]);
        }
        String name = computation.getComputationName();
        // Multimap.get() returns an empty collection for unknown keys, so no containsKey check.
        for (FunctionNode candidate : computationIndex.get(name)) {
            if (equivalent(operands, candidate.getOperands())) {
                // Remember the match so revisiting this vector is a direct map lookup.
                vectorMap.put(computation, candidate);
                return candidate;
            }
        }
        FunctionNode node = new FunctionNode(computation);
        node.addInputs(operands);
        nodes.add(node);
        vectorMap.put(computation, node);
        computationIndex.put(name, node);
        return node;
    }

    /** Pairwise {@link #equivalent(DeferredNode, DeferredNode)} over two operand lists. */
    private boolean equivalent(DeferredNode[] deferredNodeArr, List<DeferredNode> list) {
        if (deferredNodeArr.length != list.size()) {
            return false;
        }
        for (int i = 0; i < deferredNodeArr.length; i++) {
            if (!equivalent(deferredNodeArr[i], list.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Two nodes are equivalent if they are the same node, or if the first is a {@link DataNode}
     * that considers the second equivalent.
     */
    private boolean equivalent(DeferredNode deferredNode, DeferredNode deferredNode2) {
        if (deferredNode == deferredNode2) {
            return true;
        }
        if (deferredNode instanceof DataNode) {
            return ((DataNode) deferredNode).equivalent(deferredNode2);
        }
        return false;
    }

    /** Adds a node for one output of a native call and links it to that call's node. */
    private DeferredNode addOutputNode(Vector vector) {
        NativeOutputVector output = (NativeOutputVector) vector;
        OutputNode outputNode = new OutputNode(output);
        vectorMap.put(vector, outputNode);
        nodes.add(outputNode);
        addCallChild(outputNode, output.getCall());
        return outputNode;
    }

    /** Returns the node for {@code call}, creating it (and its operands' nodes) if needed. */
    private CallNode addNode(DeferredNativeCall call) {
        CallNode existing = callMap.get(call);
        if (existing != null) {
            return existing;
        }
        CallNode callNode = new CallNode(call);
        nodes.add(callNode);
        // Registered before the operands are added, matching the original ordering.
        callMap.put(call, callNode);
        addChildren(callNode, call.getOperands());
        return callNode;
    }

    /** Makes the node for {@code call} an input of {@code deferredNode}. */
    private void addCallChild(DeferredNode deferredNode, DeferredNativeCall call) {
        CallNode callNode = addNode(call);
        deferredNode.addInput(callNode);
        callNode.addOutput(deferredNode);
    }

    /** Adds {@code call} as a root, registering it like any other call node. */
    private void addRoot(DeferredNativeCall call) {
        // Going through addNode() also records the call in callMap, so a later addRoot() of one
        // of this call's outputs reuses this CallNode instead of creating a duplicate.
        rootNodes.add(addNode(call));
    }

    /** Adds each operand vector as an input of {@code deferredNode}. */
    private void addChildren(DeferredNode deferredNode, Vector[] vectorArr) {
        for (Vector vector : vectorArr) {
            DeferredNode child = addNode(vector);
            deferredNode.addInput(child);
            child.addOutput(deferredNode);
        }
    }

    /**
     * Writes the graph in Graphviz DOT format to a new temporary file and prints the file's path
     * to standard output. If the file cannot be created, the failure is reported on standard
     * error and nothing else happens.
     */
    public void dumpGraph() {
        try {
            File dotFile = File.createTempFile("deferred", ".dot");
            try (PrintWriter writer = new PrintWriter(dotFile)) {
                printGraph(writer);
            }
            System.out.println("Dumping compute graph to " + dotFile.getAbsolutePath());
        } catch (IOException e) {
            // Best-effort debugging aid: report the failure rather than hiding it.
            System.err.println("Failed to dump compute graph: " + e);
        }
    }

    /**
     * Writes every node reachable from the roots, with its edges, as a Graphviz DOT digraph.
     *
     * @param printWriter destination; it is flushed but not closed
     */
    public void printGraph(PrintWriter printWriter) {
        // Breadth-first collection of reachable nodes; the identity set skips revisits.
        Set<DeferredNode> reachable = Sets.newIdentityHashSet();
        ArrayDeque<DeferredNode> queue = new ArrayDeque<>(rootNodes);
        while (!queue.isEmpty()) {
            DeferredNode node = queue.poll();
            if (reachable.add(node)) {
                queue.addAll(node.getOperands());
            }
        }
        printWriter.println("digraph G {");
        printEdges(printWriter, reachable);
        printNodes(printWriter, reachable);
        printWriter.println("}");
        printWriter.flush();
    }

    /** Writes one {@code operand -> user} DOT edge for every operand of every node. */
    private void printEdges(PrintWriter printWriter, Set<DeferredNode> set) {
        for (DeferredNode node : set) {
            for (DeferredNode operand : node.getOperands()) {
                printWriter.println(operand.getDebugId() + " -> " + node.getDebugId());
            }
        }
    }

    /** Writes one DOT node declaration (label and shape) per node. */
    private void printNodes(PrintWriter printWriter, Set<DeferredNode> set) {
        for (DeferredNode node : set) {
            String label = escapeDotString(String.valueOf(node.getDebugLabel()));
            // Locale.ROOT: locale-sensitive lowercasing would turn e.g. "CIRCLE" into "cırcle"
            // under a Turkish default locale, which Graphviz does not recognise.
            String shape = node.getShape().name().toLowerCase(Locale.ROOT);
            printWriter.println(node.getDebugId() + " [ label=\"" + label + "\",  shape=\"" + shape + "\"]");
        }
    }

    /** Escapes double quotes so arbitrary text can be placed inside a quoted DOT string. */
    private static String escapeDotString(String text) {
        return text.replace("\"", "\\\"");
    }

    /** Returns the live list of root nodes, in insertion order. */
    public List<DeferredNode> getRoots() {
        return rootNodes;
    }

    /**
     * Returns the vector of the {@code i}-th root.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not a valid root index
     */
    public Vector getRootResult(int i) {
        return rootNodes.get(i).getVector();
    }

    /**
     * Returns the single root of this graph.
     *
     * @throws IllegalStateException if the graph does not have exactly one root
     */
    public DeferredNode getRoot() {
        Preconditions.checkState(rootNodes.size() == 1);
        return rootNodes.get(0);
    }

    /** Returns the live list of all nodes in the graph. */
    public List<DeferredNode> getNodes() {
        return nodes;
    }

    /**
     * Replaces {@code deferredNode} with {@code deferredNode2} throughout the graph: in the node
     * list, in every root position it occupies, and as an operand of each of its users.
     *
     * @param deferredNode the node being removed
     * @param deferredNode2 its replacement
     */
    public void replaceNode(DeferredNode deferredNode, DeferredNode deferredNode2) {
        nodes.remove(deferredNode);
        if (!nodes.contains(deferredNode2)) {
            nodes.add(deferredNode2);
        }
        // Replace roots in place. getRootResult(i) is positional, so moving the replacement to
        // the end (remove + add) would misalign results when there are several roots.
        for (int i = 0; i < rootNodes.size(); i++) {
            if (deferredNode.equals(rootNodes.get(i))) {
                rootNodes.set(i, deferredNode2);
            }
        }
        for (DeferredNode operand : deferredNode.getOperands()) {
            operand.removeUse(deferredNode);
        }
        // Snapshot: replaceOperand() may modify deferredNode's use list while we walk it.
        List<DeferredNode> uses = Lists.newArrayList(deferredNode.getUses());
        for (DeferredNode use : uses) {
            use.replaceOperand(deferredNode, deferredNode2);
        }
    }
}
