/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.dashscope.rag;

import com.alibaba.cloud.ai.dashscope.rag.OpenSearchConfig;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.aliyun.ha3engine.vector.Client;
import com.aliyun.ha3engine.vector.models.Config;
import com.aliyun.ha3engine.vector.models.PushDocumentsRequest;
import com.aliyun.ha3engine.vector.models.PushDocumentsResponse;
import com.aliyun.ha3engine.vector.models.QueryRequest;
import com.aliyun.ha3engine.vector.models.SearchResponse;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.util.Assert;

public class OpenSearchVector
implements VectorStore {
    private static final Logger logger = LoggerFactory.getLogger(OpenSearchVector.class);
    private static final String ID_FIELD_NAME = "id";
    private static final String CONTENT_FIELD_NAME = "content";
    private static final String METADATA_FIELD_NAME = "metadata";
    private final String tableName;
    private final String pKField;
    private final List<String> outputFields;
    private final OpenSearchClientWrapper openSearchClient;

    public OpenSearchVector(String tableName, OpenSearchConfig openSearchConfig) {
        this(tableName, List.of(CONTENT_FIELD_NAME, METADATA_FIELD_NAME), openSearchConfig);
    }

    public OpenSearchVector(String tableName, List<String> outputFields, OpenSearchConfig openSearchConfig) {
        this.tableName = tableName;
        this.outputFields = outputFields;
        this.pKField = ID_FIELD_NAME;
        try {
            Config config = Config.build(openSearchConfig.toClientParams());
            String instanceId = config.getInstanceId();
            Client client = new Client(config);
            this.openSearchClient = new OpenSearchClientWrapper(client, instanceId, tableName, this.pKField);
        }
        catch (Exception e) {
            logger.error("init OpenSearch client error", (Throwable)e);
            throw new RuntimeException(e);
        }
    }

    public void add(List<Document> documents) {
        for (Document document : documents) {
            ArrayList documentToAdd = new ArrayList();
            HashMap<String, Object> documentMap = new HashMap<String, Object>();
            HashMap<String, String> documentFields = new HashMap<String, String>();
            documentFields.put(ID_FIELD_NAME, document.getId());
            documentFields.put(CONTENT_FIELD_NAME, document.getContent());
            documentFields.put(METADATA_FIELD_NAME, JSON.toJSONString((Object)document.getMetadata()));
            documentMap.put("fields", documentFields);
            documentMap.put("cmd", "add");
            documentToAdd.add(documentMap);
            this.openSearchClient.uploadDocument(documentToAdd);
        }
    }

    public Optional<Boolean> delete(List<String> idList) {
        for (String id : idList) {
            ArrayList documentToDelete = new ArrayList();
            HashMap<String, Object> documentMap = new HashMap<String, Object>();
            HashMap<String, String> documentFields = new HashMap<String, String>();
            documentFields.put(this.pKField, id);
            documentMap.put("fields", documentFields);
            documentMap.put("cmd", "delete");
            documentToDelete.add(documentMap);
            this.openSearchClient.deleteDocument(documentToDelete);
        }
        return Optional.of(true);
    }

    public List<Document> similaritySearch(String query) {
        return this.similaritySearch(SearchRequest.query((String)query));
    }

    public List<Document> similaritySearch(SearchRequest searchRequest) {
        Assert.notNull((Object)searchRequest, (String)"The search request must not be null.");
        double similarityThreshold = searchRequest.getSimilarityThreshold();
        QueryRequest queryRequest = new QueryRequest();
        queryRequest.setTableName(this.tableName);
        queryRequest.setContent(searchRequest.getQuery());
        queryRequest.setModal("text");
        queryRequest.setTopK(Integer.valueOf(searchRequest.getTopK()));
        queryRequest.setOutputFields(this.outputFields);
        try {
            List<SimilarityResult> similarityResults = this.openSearchClient.search(queryRequest);
            return similarityResults.stream().filter(result -> result.score >= similarityThreshold).map(result -> new Document(result.id, result.content, result.metadata)).collect(Collectors.toList());
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public static class OpenSearchClientWrapper {
        private final Client client;
        private final String fullTableName;
        private final String pKField;

        public OpenSearchClientWrapper(Client client, String instanceId, String tableName, String pKField) {
            this.client = client;
            this.pKField = pKField;
            this.fullTableName = instanceId + "_" + tableName;
        }

        public void uploadDocument(List<Map<String, ?>> document) {
            PushDocumentsRequest request = new PushDocumentsRequest();
            request.setBody(document);
            try {
                PushDocumentsResponse response = this.client.pushDocuments(this.fullTableName, this.pKField, request);
                ResponseBody responseBody = new ResponseBody(response.getBody());
                if (!responseBody.isSuccess()) {
                    String errorCode = responseBody.errorCode;
                    String errorMsg = Optional.ofNullable(responseBody.errorMessage).orElse("No error message provided");
                    throw new RuntimeException(String.format("OpenSearch upload Document failed. Error code: %s. Error message: %s", errorCode, errorMsg));
                }
            }
            catch (Exception e) {
                throw new RuntimeException("OpenSearch upload Document failed.Error message:" + e.getMessage(), e);
            }
        }

        public void deleteDocument(List<Map<String, ?>> document) {
            PushDocumentsRequest request = new PushDocumentsRequest();
            request.setBody(document);
            try {
                PushDocumentsResponse response = this.client.pushDocuments(this.fullTableName, this.pKField, request);
                ResponseBody responseBody = new ResponseBody(response.getBody());
                if (!responseBody.isSuccess()) {
                    String errorCode = responseBody.errorCode;
                    String errorMsg = Optional.ofNullable(responseBody.errorMessage).orElse("No error message provided");
                    throw new RuntimeException(String.format("OpenSearch delete Documents failed. Error code: %s. Error message: %s", errorCode, errorMsg));
                }
            }
            catch (Exception e) {
                throw new RuntimeException("OpenSearch delete Documents failed. Error message:" + e.getMessage(), e);
            }
        }

        public List<SimilarityResult> search(QueryRequest queryRequest) {
            try {
                SearchResponse searchResponse = this.client.inferenceQuery(queryRequest);
                SearchResponseBody responseBody = this.getSearchResponseBody(searchResponse);
                return SearchResultParser.parse(responseBody);
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @NotNull
        private SearchResponseBody getSearchResponseBody(SearchResponse searchResponse) {
            SearchResponseBody responseBody = new SearchResponseBody(searchResponse.getBody());
            if (responseBody.hasError()) {
                String errorCode = responseBody.errorCode;
                String errorMsg = Optional.ofNullable(responseBody.errorMessage).orElse("No error message provided");
                throw new RuntimeException(String.format("OpenSearch inferenceQuery failed. Error code: %s. Error message: %s", errorCode, errorMsg));
            }
            return responseBody;
        }
    }

    public record SimilarityResult(String id, double score, String content, Map<String, Object> metadata) {
    }

    private static class SearchResultParser {
        private static final Logger logger = LoggerFactory.getLogger(SearchResultParser.class);
        private static final String FIELDS_KEY = "fields";
        private static final String SCORE_KEY = "score";

        private SearchResultParser() {
        }

        private static List<SimilarityResult> parse(SearchResponseBody responseBody) {
            ArrayList<SimilarityResult> documents = new ArrayList<SimilarityResult>();
            Integer totalCount = responseBody.totalCount;
            if (totalCount == null || totalCount <= 0) {
                return documents;
            }
            JSONArray resultArray = responseBody.result;
            if (resultArray != null && !resultArray.isEmpty()) {
                for (Object item : resultArray) {
                    documents.add(SearchResultParser.parse((JSONObject)item));
                }
            }
            return documents;
        }

        private static SimilarityResult parse(JSONObject jsonDocument) {
            String id = SearchResultParser.extractId(jsonDocument);
            String content = SearchResultParser.extractContent(jsonDocument);
            double score = SearchResultParser.extractScore(jsonDocument);
            Map<String, Object> metadata = SearchResultParser.extractMetadata(jsonDocument);
            return new SimilarityResult(id, score, content, metadata);
        }

        private static String extractContent(JSONObject jsonDocument) {
            if (jsonDocument.containsKey((Object)FIELDS_KEY)) {
                JSONObject fields = jsonDocument.getJSONObject(FIELDS_KEY);
                String content = fields.getString(OpenSearchVector.CONTENT_FIELD_NAME);
                if (content == null || content.isEmpty()) {
                    return "";
                }
                return content;
            }
            return "";
        }

        private static String extractId(JSONObject jsonDocument) {
            String id = jsonDocument.getString(OpenSearchVector.ID_FIELD_NAME);
            if (id == null || id.isEmpty()) {
                return "";
            }
            return id;
        }

        private static double extractScore(JSONObject jsonDocument) {
            return jsonDocument.getDouble(SCORE_KEY);
        }

        private static Map<String, Object> extractMetadata(JSONObject jsonDocument) {
            if (jsonDocument.containsKey((Object)FIELDS_KEY)) {
                JSONObject fields = jsonDocument.getJSONObject(FIELDS_KEY);
                String metadataStr = fields.getString(OpenSearchVector.METADATA_FIELD_NAME);
                return (Map)JSONObject.parseObject((String)metadataStr, HashMap.class);
            }
            return new HashMap<String, Object>();
        }
    }

    private record ResponseBody(Integer code, String status, String errorCode, String errorMessage) {
        private static final String CODE_KEY = "code";
        private static final String STATUS_KEY = "status";
        private static final String ERROR_CODE_KEY = "errorCode";
        private static final String ERROR_MESSAGE_KEY = "errorMsg";
        private static final Integer SUCCESS_CODE = 200;

        public ResponseBody(String pushDocumentsResponseBodyString) {
            this(JSON.parseObject((String)pushDocumentsResponseBodyString));
        }

        public ResponseBody(JSONObject jsonObject) {
            this(jsonObject.getInteger(CODE_KEY), jsonObject.getString(STATUS_KEY), jsonObject.getString(ERROR_CODE_KEY), jsonObject.getString(ERROR_MESSAGE_KEY));
        }

        public boolean isSuccess() {
            return SUCCESS_CODE.equals(this.code);
        }
    }

    private record SearchResponseBody(String errorCode, String errorMessage, Integer totalCount, JSONArray result) {
        private static final String TOTAL_COUNT_KEY = "totalCount";
        private static final String ERROR_CODE_KEY = "errorCode";
        private static final String ERROR_MESSAGE_KEY = "errorMsg";
        private static final String RESULT_KEY = "result";

        public SearchResponseBody(String searchResponseBodyString) {
            this(JSON.parseObject((String)searchResponseBodyString));
        }

        public SearchResponseBody(JSONObject jsonObject) {
            this(jsonObject.getString(ERROR_CODE_KEY), jsonObject.getString(ERROR_MESSAGE_KEY), jsonObject.getInteger(TOTAL_COUNT_KEY), jsonObject.getJSONArray(RESULT_KEY));
        }

        public boolean hasError() {
            return this.errorCode != null && !this.errorCode.isEmpty();
        }
    }
}

