/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor;

public class TransportGetTrainedModelsStatsAction
extends HandledTransportAction<GetTrainedModelsStatsAction.Request, GetTrainedModelsStatsAction.Response> {
    private final Client client;
    private final ClusterService clusterService;
    private final TrainedModelProvider trainedModelProvider;

    @Inject
    public TransportGetTrainedModelsStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, TrainedModelProvider trainedModelProvider, Client client) {
        super("cluster:monitor/xpack/ml/inference/stats/get", transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
        this.client = client;
        this.clusterService = clusterService;
        this.trainedModelProvider = trainedModelProvider;
    }

    protected void doExecute(Task task, GetTrainedModelsStatsAction.Request request, ActionListener<GetTrainedModelsStatsAction.Response> listener) {
        ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState((ClusterState)this.clusterService.state());
        GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
        ActionListener inferenceStatsListener = ActionListener.wrap(inferenceStats -> listener.onResponse((Object)responseBuilder.setInferenceStatsByModelId(inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))).build()), arg_0 -> listener.onFailure(arg_0));
        ActionListener nodesStatsListener = ActionListener.wrap(nodesStatsResponse -> {
            Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedIdsWithAliases().entrySet().stream().flatMap(entry -> Stream.concat(((Set)entry.getValue()).stream(), Stream.of((String)entry.getKey()))).collect(Collectors.toSet());
            Map<String, Set<String>> pipelineIdsByModelIdsOrAliases = InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases(this.clusterService.state(), allPossiblePipelineReferences);
            Map<String, IngestStats> modelIdIngestStats = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(nodesStatsResponse, currentMetadata, pipelineIdsByModelIdsOrAliases);
            responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
            this.trainedModelProvider.getInferenceStats(responseBuilder.getExpandedIdsWithAliases().keySet().toArray(new String[0]), (ActionListener<List<InferenceStats>>)inferenceStatsListener);
        }, arg_0 -> listener.onFailure(arg_0));
        ActionListener idsListener = ActionListener.wrap(tuple -> {
            responseBuilder.setExpandedIdsWithAliases((Map)tuple.v2()).setTotalModelCount(((Long)tuple.v1()).longValue());
            String[] ingestNodes = TransportGetTrainedModelsStatsAction.ingestNodes(this.clusterService.state());
            NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().addMetric(NodesStatsRequest.Metric.INGEST.metricName());
            ClientHelper.executeAsyncWithOrigin((Client)this.client, (String)"ml", (ActionType)NodesStatsAction.INSTANCE, (ActionRequest)nodesStatsRequest, (ActionListener)nodesStatsListener);
        }, arg_0 -> listener.onFailure(arg_0));
        this.trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), Collections.emptySet(), currentMetadata, (ActionListener<Tuple<Long, Map<String, Set<String>>>>)idsListener);
    }

    static Map<String, IngestStats> inferenceIngestStatsByModelId(NodesStatsResponse response, ModelAliasMetadata currentMetadata, Map<String, Set<String>> modelIdToPipelineId) {
        HashMap<String, IngestStats> ingestStatsMap = new HashMap<String, IngestStats>();
        Map<String, Set> trueModelIdToPipelines = modelIdToPipelineId.entrySet().stream().collect(Collectors.toMap(entry -> {
            String maybeModelId = currentMetadata.getModelId((String)entry.getKey());
            return maybeModelId == null ? (String)entry.getKey() : maybeModelId;
        }, Map.Entry::getValue, Sets::union));
        trueModelIdToPipelines.forEach((modelId, pipelineIds) -> {
            List<IngestStats> collectedStats = response.getNodes().stream().map(nodeStats -> TransportGetTrainedModelsStatsAction.ingestStatsForPipelineIds(nodeStats, pipelineIds)).collect(Collectors.toList());
            ingestStatsMap.put((String)modelId, TransportGetTrainedModelsStatsAction.mergeStats(collectedStats));
        });
        return ingestStatsMap;
    }

    static String[] ingestNodes(ClusterState clusterState) {
        return clusterState.nodes().getIngestNodes().keySet().toArray(new String[0]);
    }

    static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
        IngestStats fullNodeStats = nodeStats.getIngestStats();
        HashMap filteredProcessorStats = new HashMap(fullNodeStats.getProcessorStats());
        filteredProcessorStats.keySet().retainAll(pipelineIds);
        List<IngestStats.PipelineStat> filteredPipelineStats = fullNodeStats.getPipelineStats().stream().filter(pipelineStat -> pipelineIds.contains(pipelineStat.getPipelineId())).collect(Collectors.toList());
        CounterMetric ingestCount = new CounterMetric();
        CounterMetric ingestTimeInMillis = new CounterMetric();
        CounterMetric ingestCurrent = new CounterMetric();
        CounterMetric ingestFailedCount = new CounterMetric();
        filteredPipelineStats.forEach(pipelineStat -> {
            IngestStats.Stats stats = pipelineStat.getStats();
            ingestCount.inc(stats.getIngestCount());
            ingestTimeInMillis.inc(stats.getIngestTimeInMillis());
            ingestCurrent.inc(stats.getIngestCurrent());
            ingestFailedCount.inc(stats.getIngestFailedCount());
        });
        return new IngestStats(new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count()), filteredPipelineStats, filteredProcessorStats);
    }

    private static IngestStats mergeStats(List<IngestStats> ingestStatsList) {
        LinkedHashMap<String, IngestStatsAccumulator> pipelineStatsAcc = new LinkedHashMap<String, IngestStatsAccumulator>(ingestStatsList.size());
        LinkedHashMap<String, Map> processorStatsAcc = new LinkedHashMap<String, Map>(ingestStatsList.size());
        IngestStatsAccumulator totalStats = new IngestStatsAccumulator();
        ingestStatsList.forEach(ingestStats -> {
            ingestStats.getPipelineStats().forEach(pipelineStat -> pipelineStatsAcc.computeIfAbsent(pipelineStat.getPipelineId(), p -> new IngestStatsAccumulator()).inc(pipelineStat.getStats()));
            ingestStats.getProcessorStats().forEach((pipelineId, processorStat) -> {
                Map processorAcc = processorStatsAcc.computeIfAbsent((String)pipelineId, k -> new LinkedHashMap());
                processorStat.forEach(p -> processorAcc.computeIfAbsent(p.getName(), k -> new IngestStatsAccumulator(p.getType())).inc(p.getStats()));
            });
            totalStats.inc(ingestStats.getTotalStats());
        });
        ArrayList pipelineStatList = new ArrayList(pipelineStatsAcc.size());
        pipelineStatsAcc.forEach((pipelineId, accumulator) -> pipelineStatList.add(new IngestStats.PipelineStat(pipelineId, accumulator.build())));
        LinkedHashMap processorStatList = new LinkedHashMap(processorStatsAcc.size());
        processorStatsAcc.forEach((pipelineId, accumulatorMap) -> {
            ArrayList processorStats = new ArrayList(accumulatorMap.size());
            accumulatorMap.forEach((processorName, acc) -> processorStats.add(new IngestStats.ProcessorStat(processorName, acc.type, acc.build())));
            processorStatList.put(pipelineId, processorStats);
        });
        return new IngestStats(totalStats.build(), pipelineStatList, processorStatList);
    }

    private static class IngestStatsAccumulator {
        CounterMetric ingestCount = new CounterMetric();
        CounterMetric ingestTimeInMillis = new CounterMetric();
        CounterMetric ingestCurrent = new CounterMetric();
        CounterMetric ingestFailedCount = new CounterMetric();
        String type;

        IngestStatsAccumulator() {
        }

        IngestStatsAccumulator(String type) {
            this.type = type;
        }

        void inc(IngestStats.Stats s) {
            this.ingestCount.inc(s.getIngestCount());
            this.ingestTimeInMillis.inc(s.getIngestTimeInMillis());
            this.ingestCurrent.inc(s.getIngestCurrent());
            this.ingestFailedCount.inc(s.getIngestFailedCount());
        }

        IngestStats.Stats build() {
            return new IngestStats.Stats(this.ingestCount.count(), this.ingestTimeInMillis.count(), this.ingestCurrent.count(), this.ingestFailedCount.count());
        }
    }
}

