package aima.core.probability.bayes.exact;

import aima.core.probability.CategoricalDistribution;
import aima.core.probability.Factor;
import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.BayesInference;
import aima.core.probability.bayes.BayesianNetwork;
import aima.core.probability.bayes.FiniteNode;
import aima.core.probability.bayes.Node;
import aima.core.probability.proposition.AssignmentProposition;
import aima.core.probability.util.ProbabilityTable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:aima/core/probability/bayes/exact/EliminationAsk.class */
/**
 * Answers probabilistic queries over a {@link BayesianNetwork} by building one
 * factor per network variable, summing out the variables that are neither
 * queried nor observed, multiplying the remaining factors together and
 * normalizing the result.
 *
 * <p>Only networks whose nodes are all {@link FiniteNode}s are supported.
 */
public class EliminationAsk implements BayesInference {
    /**
     * Table over no random variables holding the single value 1.0. It is passed
     * as the multiplier to {@code pointwiseProductPOS} so that call can be used
     * purely for the variable order argument it accepts.
     */
    private static final ProbabilityTable _identity = new ProbabilityTable(new double[]{1.0d}, new RandomVariable[0]);

    /**
     * Computes the distribution over the query variables given the evidence.
     *
     * @param randomVariableArr        the query variables; also passed to
     *                                 {@code pointwiseProductPOS} as the requested
     *                                 variable order of the final product
     * @param assignmentPropositionArr the evidence, as variable assignments
     * @param bayesianNetwork          the network to run the query against
     * @return the normalized product of the factors that remain once every
     *         hidden variable has been summed out
     * @throws IllegalArgumentException if a network node is not a {@link FiniteNode}
     */
    public CategoricalDistribution eliminationAsk(RandomVariable[] randomVariableArr, AssignmentProposition[] assignmentPropositionArr, BayesianNetwork bayesianNetwork) {
        Set<RandomVariable> hidden = new HashSet<RandomVariable>();
        List<RandomVariable> variables = new ArrayList<RandomVariable>();
        calculateVariables(randomVariableArr, assignmentPropositionArr, bayesianNetwork, hidden, variables);

        List<Factor> factors = new ArrayList<Factor>();
        for (RandomVariable variable : order(bayesianNetwork, variables)) {
            // The newest factor goes to the front of the list.
            factors.add(0, makeFactor(variable, assignmentPropositionArr, bayesianNetwork));
            if (hidden.contains(variable)) {
                factors = sumOut(variable, factors, bayesianNetwork);
            }
        }

        Factor product = pointwiseProduct(factors);
        // Multiplying by the identity table leaves the values alone; the call is
        // made for its variable-order argument (presumably it lays the result out
        // in query-variable order - confirm against ProbabilityTable).
        Factor orderedProduct = product.pointwiseProductPOS(_identity, randomVariableArr);
        return ((ProbabilityTable) orderedProduct).normalize();
    }

    /**
     * {@inheritDoc}
     *
     * <p>Delegates directly to
     * {@link #eliminationAsk(RandomVariable[], AssignmentProposition[], BayesianNetwork)}.
     */
    @Override // aima.core.probability.bayes.BayesInference
    public CategoricalDistribution ask(RandomVariable[] randomVariableArr, AssignmentProposition[] assignmentPropositionArr, BayesianNetwork bayesianNetwork) {
        return eliminationAsk(randomVariableArr, assignmentPropositionArr, bayesianNetwork);
    }

    /**
     * Populates the two output collections used by the elimination loop.
     *
     * <p>Subclasses may override this to choose a different set of variables.
     *
     * @param randomVariableArr        the query variables
     * @param assignmentPropositionArr the evidence assignments
     * @param bayesianNetwork          the network being queried
     * @param set                      output: receives every element of
     *                                 {@code collection} except the query variables
     *                                 and the variables in the scope of any evidence
     *                                 proposition, i.e. the hidden variables
     * @param collection               output: receives the variables returned by
     *                                 {@code getVariablesInTopologicalOrder()}
     */
    protected void calculateVariables(RandomVariable[] randomVariableArr, AssignmentProposition[] assignmentPropositionArr, BayesianNetwork bayesianNetwork, Set<RandomVariable> set, Collection<RandomVariable> collection) {
        collection.addAll(bayesianNetwork.getVariablesInTopologicalOrder());
        set.addAll(collection);
        for (RandomVariable queryVariable : randomVariableArr) {
            set.remove(queryVariable);
        }
        for (AssignmentProposition evidence : assignmentPropositionArr) {
            set.removeAll(evidence.getScope());
        }
    }

    /**
     * Decides the order in which variables are processed. This implementation
     * returns the given variables in reverse iteration order and ignores the
     * network; subclasses may override it to supply another ordering.
     *
     * @param bayesianNetwork the network being queried (unused here)
     * @param collection      the variables to order
     * @return a new list holding the variables of {@code collection} reversed
     */
    protected List<RandomVariable> order(BayesianNetwork bayesianNetwork, Collection<RandomVariable> collection) {
        List<RandomVariable> reversed = new ArrayList<RandomVariable>(collection);
        Collections.reverse(reversed);
        return reversed;
    }

    /**
     * Builds the factor for one variable from its node's conditional probability
     * table, restricted by the evidence that table mentions.
     *
     * @param variable the variable whose factor is wanted
     * @param evidence all evidence assignments of the query
     * @param network  the network that owns the variable's node
     * @return the factor produced by the table for the relevant evidence
     * @throws IllegalArgumentException if the variable's node is not a {@link FiniteNode}
     */
    private Factor makeFactor(RandomVariable variable, AssignmentProposition[] evidence, BayesianNetwork network) {
        Node node = network.getNode(variable);
        if (!(node instanceof FiniteNode)) {
            throw new IllegalArgumentException("Elimination-Ask only works with finite Nodes.");
        }
        FiniteNode finiteNode = (FiniteNode) node;

        // Keep only the evidence whose variable appears in this node's table.
        List<AssignmentProposition> relevantEvidence = new ArrayList<AssignmentProposition>();
        for (AssignmentProposition proposition : evidence) {
            if (finiteNode.getCPT().contains(proposition.getTermVariable())) {
                relevantEvidence.add(proposition);
            }
        }
        return finiteNode.getCPT().getFactorFor(relevantEvidence.toArray(new AssignmentProposition[relevantEvidence.size()]));
    }

    /**
     * Eliminates one variable from a list of factors. Factors that do not
     * contain the variable are carried over unchanged and in their original
     * order; the factors that do contain it are multiplied together, the
     * variable is summed out of that product, and the result is appended last.
     *
     * @param variable the variable to eliminate
     * @param factors  the current factors; at least one must contain
     *                 {@code variable}
     * @param network  the network being queried (not used by this method)
     * @return a new list of factors
     */
    private List<Factor> sumOut(RandomVariable variable, List<Factor> factors, BayesianNetwork network) {
        List<Factor> result = new ArrayList<Factor>();
        List<Factor> toMultiply = new ArrayList<Factor>();
        for (Factor factor : factors) {
            if (factor.contains(variable)) {
                toMultiply.add(factor);
            } else {
                result.add(factor);
            }
        }
        result.add(pointwiseProduct(toMultiply).sumOut(variable));
        return result;
    }

    /**
     * Multiplies the given factors together from left to right.
     *
     * @param factors the factors to multiply; must not be empty
     * @return the first factor multiplied by each following one in turn
     * @throws IndexOutOfBoundsException if {@code factors} is empty
     */
    private Factor pointwiseProduct(List<Factor> factors) {
        Factor product = factors.get(0);
        for (int i = 1; i < factors.size(); i++) {
            product = product.pointwiseProduct(factors.get(i));
        }
        return product;
    }
}
