/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.text_embedding;

import ai.djl.Model;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.opensearch.ml.common.model.BaseModelConfig;
import org.opensearch.ml.engine.algorithms.text_embedding.HuggingfaceTextEmbeddingServingTranslator;
import org.opensearch.ml.engine.algorithms.text_embedding.HuggingfaceTextEmbeddingTranslator;

public class HuggingfaceTextEmbeddingTranslatorFactory
implements TranslatorFactory {
    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<Pair<Type, Type>>();
    private final BaseModelConfig.PoolingMode poolingMode;
    private boolean normalizeResult;
    private final String modelType;
    private final boolean neuron;

    public HuggingfaceTextEmbeddingTranslatorFactory(BaseModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType, boolean neuron) {
        this.poolingMode = poolingMode == null ? BaseModelConfig.PoolingMode.MEAN : poolingMode;
        this.normalizeResult = normalizeResult;
        this.modelType = modelType;
        this.neuron = neuron;
    }

    public Set<Pair<Type, Type>> getSupportedTypes() {
        return SUPPORTED_TYPES;
    }

    public <I, O> Translator<I, O> newInstance(Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) throws TranslateException {
        Path modelPath = model.getModelPath();
        try {
            HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder(arguments).optTokenizerPath(modelPath).optManager(model.getNDManager()).build();
            boolean inputTokenTypeIds = this.neuron && ("bert".equalsIgnoreCase(this.modelType) || "albert".equalsIgnoreCase(this.modelType));
            HuggingfaceTextEmbeddingTranslator translator = HuggingfaceTextEmbeddingTranslator.builder(tokenizer, arguments).optPoolingMode(this.poolingMode.getName()).optNormalize(this.normalizeResult).optInputTokenTypeIds(inputTokenTypeIds).build();
            if (input == String.class && output == float[].class) {
                return translator;
            }
            if (input == Input.class && output == Output.class) {
                return new HuggingfaceTextEmbeddingServingTranslator(translator);
            }
            throw new IllegalArgumentException("Unsupported input/output types.");
        }
        catch (IOException e) {
            throw new TranslateException("Failed to load tokenizer.", (Throwable)e);
        }
    }

    static {
        SUPPORTED_TYPES.add((Pair<Type, Type>)new Pair(String.class, float[].class));
        SUPPORTED_TYPES.add((Pair<Type, Type>)new Pair(Input.class, Output.class));
    }
}

