/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.nodes.physical.stream;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlDescriptorOperator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions;
import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
import org.apache.flink.table.ml.ModelProvider;
import org.apache.flink.table.ml.PredictRuntimeProvider;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.RexModelCall;
import org.apache.flink.table.planner.calcite.RexTableArgCall;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMLPredictTableFunction;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
import org.apache.flink.table.planner.plan.utils.MLPredictUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;

/**
 * Stream physical node for an {@code ML_PREDICT} table function invocation.
 *
 * <p>The call held by the wrapped {@link FlinkLogicalTableFunctionScan} is searched (by operand
 * type) for three operands: a {@link RexTableArgCall} describing the input table, a {@link
 * RexModelCall} describing the model, and a {@link RexCall} whose operator is a {@link
 * SqlDescriptorOperator} naming the feature columns. The node is translated into a {@link
 * StreamExecMLPredictTableFunction}.
 */
public class StreamPhysicalMLPredictTableFunction extends SingleRel implements StreamPhysicalRel {

    /** Row type produced by this node; returned from {@link #deriveRowType()}. */
    private final RelDataType outputRowType;

    /** Logical scan holding the original {@code ML_PREDICT} call. */
    private final FlinkLogicalTableFunctionScan scan;

    /** Runtime options forwarded to the predict spec and used to resolve sync/async mode. */
    private final Map<String, String> runtimeConfig;

    public StreamPhysicalMLPredictTableFunction(
            RelOptCluster cluster,
            RelTraitSet traits,
            RelNode inputRel,
            FlinkLogicalTableFunctionScan scan,
            RelDataType outputRowType,
            Map<String, String> runtimeConfig) {
        super(cluster, traits, inputRel);
        this.scan = scan;
        this.outputRowType = outputRowType;
        this.runtimeConfig = runtimeConfig;
    }

    @Override
    public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
        return new StreamPhysicalMLPredictTableFunction(
                getCluster(), traitSet, inputs.get(0), scan, getRowType(), runtimeConfig);
    }

    /** This node never requires watermarks. */
    @Override
    public boolean requireWatermark() {
        return false;
    }

    /**
     * Builds the {@link StreamExecMLPredictTableFunction} exec node from the model operand, the
     * feature descriptor and the runtime configuration.
     */
    @Override
    public ExecNode<?> translateToExecNode() {
        RexModelCall modelCall = extractOperand(RexModelCall.class);
        ReadableConfig tableConfig = ShortcutUtils.unwrapTableConfig(this);
        return new StreamExecMLPredictTableFunction(
                tableConfig,
                buildMLPredictSpec(runtimeConfig),
                buildModelSpec(modelCall),
                buildAsyncOptions(modelCall, runtimeConfig),
                InputProperty.DEFAULT,
                FlinkTypeFactory.toLogicalRowType(getRowType()),
                getRelDetailedDescription());
    }

    @Override
    protected RelDataType deriveRowType() {
        return outputRowType;
    }

    @Override
    public RelWriter explainTerms(RelWriter pw) {
        return super.explainTerms(pw)
                .item("invocation", scan.getCall())
                .item("rowType", getRowType());
    }

    /** Returns the original {@code ML_PREDICT} call held by the logical scan. */
    public RexNode getMLPredictCall() {
        return scan.getCall();
    }

    /**
     * Resolves every column named by the descriptor operand to its index in the input table and
     * packages the resulting field references together with {@code runtimeConfig}.
     *
     * @throws TableException if a descriptor operand is not a literal or names a column that is
     *     not part of the input table's row type
     */
    private MLPredictSpec buildMLPredictSpec(Map<String, String> runtimeConfig) {
        RexTableArgCall tableCall = extractOperand(RexTableArgCall.class);
        RexCall descriptorCall =
                extractOperand(
                        RexCall.class,
                        call -> call.getOperator() instanceof SqlDescriptorOperator);

        RelDataType inputType = tableCall.getType();
        Map<String, Integer> column2Index = new HashMap<>();
        List<String> fieldNames = inputType.getFieldNames();
        for (int i = 0; i < fieldNames.size(); i++) {
            column2Index.put(fieldNames.get(i), i);
        }

        List<FunctionCallUtil.FunctionParam> features =
                descriptorCall.getOperands().stream()
                        .map(operand -> toFeatureParam(operand, column2Index, inputType))
                        .collect(Collectors.toList());
        return new MLPredictSpec(features, runtimeConfig);
    }

    /**
     * Converts one descriptor operand (expected to be a string literal holding a column name) into
     * a field reference into the input row.
     */
    private static FunctionCallUtil.FunctionParam toFeatureParam(
            RexNode operand, Map<String, Integer> column2Index, RelDataType inputType) {
        if (!(operand instanceof RexLiteral)) {
            throw new TableException(
                    String.format("Unknown operand for descriptor operator: %s.", operand));
        }
        String fieldName = RexLiteral.stringValue((RexLiteral) operand);
        Integer index = column2Index.get(fieldName);
        if (index == null) {
            throw new TableException(
                    String.format(
                            "Field %s is not found in input schema: %s.", fieldName, inputType));
        }
        return new FunctionCallUtil.FieldRef(index);
    }

    /** Wraps the model referenced by {@code modelCall} and its provider into a {@link ModelSpec}. */
    private ModelSpec buildModelSpec(RexModelCall modelCall) {
        ModelSpec modelSpec = new ModelSpec(modelCall.getContextResolvedModel());
        modelSpec.setModelProvider(modelCall.getModelProvider());
        return modelSpec;
    }

    /**
     * Returns the merged async options when the prediction runs in async mode, or {@code null}
     * when it runs synchronously.
     */
    @Nullable
    private FunctionCallUtil.AsyncOptions buildAsyncOptions(
            RexModelCall modelCall, Map<String, String> runtimeConfig) {
        boolean isAsyncEnabled = isAsyncMLPredict(modelCall.getModelProvider(), runtimeConfig);
        if (isAsyncEnabled) {
            return MLPredictUtil.getMergedMLPredictAsyncOptions(
                    runtimeConfig, ShortcutUtils.unwrapTableConfig(getCluster()));
        }
        return null;
    }

    /**
     * Finds the first operand of the {@code ML_PREDICT} call that is an instance of {@code
     * operandClass} and satisfies {@code predicate}.
     */
    private <T extends RexNode> Optional<T> extractOptionalOperand(
            Class<T> operandClass, Predicate<? super T> predicate) {
        return ((RexCall) scan.getCall())
                .getOperands().stream()
                        .filter(operandClass::isInstance)
                        .map(operandClass::cast)
                        .filter(predicate)
                        .findFirst();
    }

    /**
     * Like {@link #extractOptionalOperand(Class, Predicate)} but fails when no operand matches.
     *
     * @throws TableException if the call contains no matching operand
     */
    private <T extends RexNode> T extractOperand(
            Class<T> operandClass, Predicate<? super T> predicate) {
        return extractOptionalOperand(operandClass, predicate)
                .orElseThrow(
                        () ->
                                new TableException(
                                        String.format(
                                                "MLPredict doesn't contain specified operand: %s",
                                                scan.getCall().toString())));
    }

    /** Returns the first operand of type {@code operandClass}; fails if there is none. */
    private <T extends RexNode> T extractOperand(Class<T> operandClass) {
        return extractOperand(operandClass, operand -> true);
    }

    /**
     * Decides whether prediction runs asynchronously.
     *
     * <p>When {@link MLPredictRuntimeConfigOptions#ASYNC} is absent from {@code runtimeConfig},
     * async is chosen whenever the provider supports it. When the option is present, the requested
     * mode must be supported by the provider.
     *
     * @throws TableException if the provider supports neither mode, or does not support the
     *     explicitly requested mode
     */
    private boolean isAsyncMLPredict(ModelProvider provider, Map<String, String> runtimeConfig) {
        Optional<Boolean> requiredMode =
                Configuration.fromMap(runtimeConfig)
                        .getOptional(MLPredictRuntimeConfigOptions.ASYNC);
        boolean syncFound = provider instanceof PredictRuntimeProvider;
        boolean asyncFound = provider instanceof AsyncPredictRuntimeProvider;

        if (!syncFound && !asyncFound) {
            // A null provider must still produce a TableException rather than an NPE.
            String providerName = provider == null ? "null" : provider.getClass().getName();
            throw new TableException(
                    String.format("Unknown model provider found: %s.", providerName));
        }
        if (requiredMode.isEmpty()) {
            return asyncFound;
        }
        if (requiredMode.get()) {
            if (!asyncFound) {
                throw new TableException(
                        String.format(
                                "Require async mode, but model provider %s doesn't support async mode.",
                                provider.getClass().getName()));
            }
            return true;
        }
        if (!syncFound) {
            throw new TableException(
                    String.format(
                            "Require sync mode, but model provider %s doesn't support sync mode.",
                            provider.getClass().getName()));
        }
        return false;
    }
}

