package org.molgenis.data.semanticsearch.service.impl;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.StringUtils;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.util.packed.PackedInts;
import org.elasticsearch.common.base.Joiner;
import org.molgenis.data.QueryRule;
import org.molgenis.data.semanticsearch.service.QueryExpansionService;
import org.molgenis.data.semanticsearch.service.bean.SearchParam;
import org.molgenis.data.semanticsearch.service.bean.TagGroup;
import org.molgenis.data.semanticsearch.utils.SemanticSearchServiceUtils;
import org.molgenis.ontology.core.model.OntologyTerm;
import org.molgenis.ontology.core.service.OntologyService;
import org.molgenis.ontology.ic.TermFrequencyService;
import org.molgenis.ontology.utils.Stemmer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;

/* loaded from: input_file:WEB-INF/lib/molgenis-data-semanticsearch-2.0.0-SNAPSHOT.jar:org/molgenis/data/semanticsearch/service/impl/QueryExpansionServiceImpl.class */
/**
 * Expands a {@link SearchParam} into a single {@code DIS_MAX} {@link QueryRule}.
 *
 * <p>The expansion combines:
 * <ul>
 *   <li>one dis-max rule over the lexical queries. Each query is stop-word filtered and each of its terms is
 *       boosted relative to the highest value returned by {@link TermFrequencyService#getTermFrequency};</li>
 *   <li>when semantic search is enabled, one rule per group of {@link TagGroup}s that matched the same words.
 *       Each such rule is built from the synonyms of the tagged ontology terms and of their children, and child
 *       synonyms are boosted by {@code 0.5^distance}.</li>
 * </ul>
 * Every resulting query string is matched with {@code FUZZY_MATCH} against both {@code label} and
 * {@code description}.
 */
public class QueryExpansionServiceImpl implements QueryExpansionService {
    private static final Logger LOG = LoggerFactory.getLogger(QueryExpansionServiceImpl.class);

    /** Boost set on the dis-max rule that wraps all lexical queries. */
    private static final float LEXICAL_QUERY_BOOSTVALUE = 1.0f;
    private static final String CARET_CHARACTER = "^";
    private static final String ESCAPED_CARET_CHARACTER = "\\^";

    /**
     * Formats boost values with at most five fraction digits and a '.' separator.
     * {@link DecimalFormat} is not thread-safe. This formatter is reached from the cache loader, which Guava
     * may run on several threads at once, so every thread gets its own instance.
     */
    private static final ThreadLocal<DecimalFormat> DECIMAL_FORMATTER = ThreadLocal.withInitial(
            () -> new DecimalFormat("#.#####", new DecimalFormatSymbols(Locale.US)));

    private final OntologyService ontologyService;
    private final TermFrequencyService termFrequencyService;
    private final Joiner termJoiner = Joiner.on(' ');

    /** Expanded queries per ontology term: at most 1000 entries, each expiring one hour after being written. */
    private final LoadingCache<OntologyTerm, List<String>> cachedOntologyTermQuery = CacheBuilder.newBuilder()
            .maximumSize(1000)
            .expireAfterWrite(1, TimeUnit.HOURS)
            .build(new CacheLoader<OntologyTerm, List<String>>() {
                @Override
                public List<String> load(OntologyTerm ontologyTerm) {
                    return QueryExpansionServiceImpl.this.getExpandedQueriesFromOntologyTerm(ontologyTerm);
                }
            });

    @Autowired
    public QueryExpansionServiceImpl(OntologyService ontologyService, TermFrequencyService termFrequencyService) {
        this.ontologyService = Objects.requireNonNull(ontologyService);
        this.termFrequencyService = Objects.requireNonNull(termFrequencyService);
    }

    /**
     * Builds the query for a search.
     *
     * @param searchParam lexical queries, tag groups and the semantic-search switch
     * @return a {@code DIS_MAX} rule over the lexical rule and the ontology-term rules, or {@code null} when
     *         neither produced a rule
     */
    @Override
    public QueryRule expand(SearchParam searchParam) {
        List<QueryRule> rules = new ArrayList<>();

        Set<String> lexicalQueries = searchParam.getLexicalQueries();
        if (lexicalQueries != null && !lexicalQueries.isEmpty()) {
            List<String> boostedQueries = lexicalQueries.stream()
                    .filter(StringUtils::isNotBlank)
                    .map(this::parseQueryString)
                    .map(this::boostLexicalQuery)
                    .collect(Collectors.toList());
            QueryRule lexicalRule = createDisMaxQueryRuleForTerms(boostedQueries, LEXICAL_QUERY_BOOSTVALUE);
            if (lexicalRule != null) {
                rules.add(lexicalRule);
            }
        }

        List<TagGroup> tagGroups = searchParam.getTagGroups();
        if (searchParam.isSemanticSearchEnabled() && tagGroups != null) {
            // Group the tag groups by the words they matched, keeping first-encounter order of the keys.
            // The key is typed as Object so the grouping does not depend on getMatchedWords()'s declared type.
            Multimap<Object, TagGroup> tagGroupsByMatchedWords = LinkedHashMultimap.create();
            tagGroups.forEach(tagGroup -> tagGroupsByMatchedWords.put(tagGroup.getMatchedWords(), tagGroup));
            for (Object matchedWords : tagGroupsByMatchedWords.keySet()) {
                QueryRule ontologyTermRule =
                        createQueryRuleForOntologyTerms(Lists.newArrayList(tagGroupsByMatchedWords.get(matchedWords)));
                if (ontologyTermRule != null) {
                    rules.add(ontologyTermRule);
                }
            }
        }

        if (rules.isEmpty()) {
            return null;
        }
        QueryRule disMaxRule = new QueryRule(rules);
        disMaxRule.setOperator(QueryRule.Operator.DIS_MAX);
        return disMaxRule;
    }

    /**
     * Builds the rule for tag groups that matched the same words.
     *
     * <p>The ontology terms are first grouped by shared synonyms. A single group becomes one dis-max rule boosted
     * with the first tag group's score. Several groups become a {@code SHOULD} rule, boosted with that score, whose
     * nested dis-max rules are boosted with the normalized term frequency of each group.
     *
     * @param list tag groups that matched the same words; the first one provides the score and the seed terms
     * @return the rule, or {@code null} when {@code list} is null/empty or yields no queries
     */
    QueryRule createQueryRuleForOntologyTerms(List<TagGroup> list) {
        if (list == null || list.isEmpty()) {
            return null;
        }
        float score = list.get(0).getScore();
        Multimap<OntologyTerm, OntologyTerm> termsBySynonym = groupAtomicOntologyTermsBySynonym(list);
        Set<OntologyTerm> groupKeys = termsBySynonym.keySet();
        if (groupKeys.isEmpty()) {
            // The first tag group carried no ontology terms; there is nothing to expand.
            return null;
        }
        if (groupKeys.size() == 1) {
            OntologyTerm onlyKey = Iterables.getOnlyElement(groupKeys);
            return createDisMaxQueryRuleForTerms(collectCachedQueries(termsBySynonym.get(onlyKey)), score);
        }

        Map<OntologyTerm, Float> groupBoosts = normalizeBoostValueForOntologyTermGroup(termsBySynonym);
        List<QueryRule> groupRules = groupKeys.stream()
                .map(key -> createDisMaxQueryRuleForTerms(collectCachedQueries(termsBySynonym.get(key)),
                        groupBoosts.get(key)))
                // A group without any query strings produces no rule; never nest null rules.
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        return createShouldQueryRule(groupRules, score);
    }

    /** Concatenates the cached expanded queries of the given ontology terms, in iteration order. */
    private List<String> collectCachedQueries(Iterable<OntologyTerm> ontologyTerms) {
        return StreamSupport.stream(ontologyTerms.spliterator(), false)
                .flatMap(ontologyTerm -> getCachedQueriesForOntologyTerm(ontologyTerm).stream())
                .collect(Collectors.toList());
    }

    /**
     * Returns the expanded queries for an ontology term, computing and caching them on first use.
     *
     * @return the cached queries, or an empty list when the cache reports an {@link ExecutionException}
     */
    List<String> getCachedQueriesForOntologyTerm(OntologyTerm ontologyTerm) {
        try {
            return cachedOntologyTermQuery.get(ontologyTerm);
        } catch (ExecutionException e) {
            // Keep the best-effort behavior, but retain the cause so the failure can be diagnosed.
            LOG.error(e.getMessage(), e);
            return Collections.emptyList();
        }
    }

    /**
     * Computes the query strings for an ontology term.
     *
     * <p>The result has two parts. First come the term's own lower-cased synonyms with stop words removed. Then come
     * the synonyms of the children returned by {@code ontologyService.getChildren(ontologyTerm, 3)}; each of their
     * words is boosted by {@code 0.5^distance}, where distance comes from
     * {@link OntologyService#getOntologyTermDistance}.
     */
    List<String> getExpandedQueriesFromOntologyTerm(OntologyTerm ontologyTerm) {
        List<String> queries = SemanticSearchServiceUtils.getLowerCaseTerms(ontologyTerm).stream()
                .map(this::parseQueryString)
                .collect(Collectors.toCollection(ArrayList::new));

        LOG.trace("Started retrieving the children for the OntologyTerm: {}", ontologyTerm);
        // NOTE(review): the second getChildren argument is presumably a maximum depth — confirm in OntologyService.
        List<String> childQueries = StreamSupport
                .stream(ontologyService.getChildren(ontologyTerm, 3).spliterator(), false)
                .flatMap(child -> {
                    // The distance depends only on the child, so compute it once instead of once per synonym.
                    double boost = Math.pow(0.5d,
                            ontologyService.getOntologyTermDistance(ontologyTerm, child).intValue());
                    return SemanticSearchServiceUtils.getLowerCaseTerms(child).stream()
                            .map(synonym -> parseBoostQueryString(synonym, boost));
                })
                .collect(Collectors.toList());
        LOG.trace("Retrieved {}", childQueries.size());

        queries.addAll(childQueries);
        return queries;
    }

    /**
     * Computes one boost per synonym group.
     *
     * <p>Each group gets the highest term frequency of the shortest synonym of any of its members. The boosts are
     * then divided by the largest group value, so the strongest group gets 1.
     */
    private Map<OntologyTerm, Float> normalizeBoostValueForOntologyTermGroup(
            Multimap<OntologyTerm, OntologyTerm> multimap) {
        Map<OntologyTerm, Double> groupFrequencies = multimap.keySet().stream()
                .collect(Collectors.toMap(Function.identity(), key -> multimap.get(key).stream()
                        .map(SemanticSearchServiceUtils::getLowerCaseTerms)
                        .map(this::getBestInverseDocumentFrequency)
                        // Terms without synonyms have no frequency; skip them instead of unboxing null.
                        .filter(Objects::nonNull)
                        .mapToDouble(Float::doubleValue)
                        .max()
                        .orElse(1.0d)));

        double maxFrequency = groupFrequencies.values().stream()
                .mapToDouble(Double::doubleValue)
                .max()
                .orElse(1.0d);
        // Avoid NaN/Infinity boosts when every frequency is zero.
        double divisor = maxFrequency > 0 ? maxFrequency : 1.0d;

        return groupFrequencies.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> (float) (entry.getValue() / divisor)));
    }

    /**
     * Groups ontology terms by shared synonyms.
     *
     * <p>Every term of the first tag group becomes a key mapped to itself. A term from a later tag group is added
     * under the first existing key it shares a stemmed synonym with; a term that matches no key is dropped.
     */
    private Multimap<OntologyTerm, OntologyTerm> groupAtomicOntologyTermsBySynonym(List<TagGroup> list) {
        Multimap<OntologyTerm, OntologyTerm> termsBySynonym = LinkedHashMultimap.create();
        list.get(0).getOntologyTerms().forEach(seed -> termsBySynonym.put(seed, seed));

        list.stream()
                .skip(1L)
                .flatMap(tagGroup -> tagGroup.getOntologyTerms().stream())
                .filter(candidate -> !termsBySynonym.containsKey(candidate))
                .forEach(candidate -> termsBySynonym.keySet().stream()
                        .filter(key -> hasSameSynonyms(key, candidate))
                        .findFirst()
                        // findFirst() has completed before put(), so the key set is not modified mid-stream.
                        .ifPresent(key -> termsBySynonym.put(key, candidate)));
        return termsBySynonym;
    }

    /** Returns whether the two terms share at least one lower-cased synonym after stemming. */
    private boolean hasSameSynonyms(OntologyTerm ontologyTerm, OntologyTerm ontologyTerm2) {
        Set<String> stemmedSynonyms = SemanticSearchServiceUtils.getLowerCaseTerms(ontologyTerm).stream()
                .map(Stemmer::cleanStemPhrase)
                .collect(Collectors.toSet());
        return SemanticSearchServiceUtils.getLowerCaseTerms(ontologyTerm2).stream()
                .map(Stemmer::cleanStemPhrase)
                .anyMatch(stemmedSynonyms::contains);
    }

    /**
     * Builds a {@code DIS_MAX} rule that fuzzy-matches every distinct, non-empty query against label and
     * description.
     *
     * <p>Each query is escaped for the Lucene query parser. Carets are then unescaped so {@code term^boost} syntax
     * survives.
     *
     * @param list query strings; duplicates are removed while keeping the first occurrence
     * @param f boost for the rule; applied only when non-null and strictly positive
     * @return the rule, or {@code null} when {@code list} is null or contains no non-empty query
     */
    QueryRule createDisMaxQueryRuleForTerms(List<String> list, Float f) {
        if (list == null) {
            return null;
        }
        List<QueryRule> rules = new ArrayList<>();
        for (String query : Sets.newLinkedHashSet(list)) {
            if (StringUtils.isEmpty(query)) {
                continue;
            }
            String escaped = QueryParser.escape(query).replace(ESCAPED_CARET_CHARACTER, CARET_CHARACTER);
            rules.add(new QueryRule("label", QueryRule.Operator.FUZZY_MATCH, escaped));
            rules.add(new QueryRule("description", QueryRule.Operator.FUZZY_MATCH, escaped));
        }
        if (rules.isEmpty()) {
            return null;
        }
        QueryRule disMaxRule = new QueryRule(rules);
        disMaxRule.setOperator(QueryRule.Operator.DIS_MAX);
        // Compare as float: intValue() != 0 would silently drop every fractional boost such as 0.5.
        if (f != null && f.floatValue() > 0.0f) {
            disMaxRule.setValue(f);
        }
        return disMaxRule;
    }

    /**
     * Wraps the given rules in a {@code SHOULD} rule.
     *
     * @param f boost for the rule; applied only when non-null and strictly positive
     * @return the rule, or {@code null} when {@code list} is null or empty
     */
    private QueryRule createShouldQueryRule(List<QueryRule> list, Float f) {
        if (list == null || list.isEmpty()) {
            return null;
        }
        QueryRule shouldRule = new QueryRule(Lists.newArrayList());
        shouldRule.setOperator(QueryRule.Operator.SHOULD);
        shouldRule.getNestedRules().addAll(list);
        if (f != null && f.floatValue() > 0.0f) {
            shouldRule.setValue(f);
        }
        return shouldRule;
    }

    /**
     * Appends {@code ^boost} to every unique term of the query.
     *
     * <p>The boost is the term's frequency divided by the highest frequency among the query's terms.
     */
    private String boostLexicalQuery(String str) {
        List<String> terms = new ArrayList<>(SemanticSearchServiceUtils.splitIntoUniqueTerms(str));
        Map<String, Float> termFrequencies = terms.stream()
                .collect(Collectors.toMap(Function.identity(),
                        term -> Float.valueOf(termFrequencyService.getTermFrequency(term)),
                        (first, second) -> first));

        double maxFrequency = termFrequencies.values().stream()
                .mapToDouble(Float::doubleValue)
                .max()
                .orElse(1.0d);
        // Avoid "term^NaN" / "term^Infinity" when every frequency is zero.
        double divisor = maxFrequency > 0 ? maxFrequency : 1.0d;

        List<String> boostedTerms = terms.stream()
                .map(term -> term + CARET_CHARACTER + (float) (termFrequencies.get(term) / divisor))
                .collect(Collectors.toList());
        return termJoiner.join(boostedTerms);
    }

    /**
     * Returns the term frequency of the shortest term in {@code set}, or {@code null} when {@code set} is empty.
     * When several terms share the shortest length, the first in iteration order is used.
     */
    private Float getBestInverseDocumentFrequency(Set<String> set) {
        return set.stream()
                .min(Comparator.comparingInt(String::length))
                .map(shortest -> Float.valueOf(termFrequencyService.getTermFrequency(shortest)))
                .orElse(null);
    }

    /** Removes stop words from {@code str} and joins the remaining words with single spaces. */
    String parseQueryString(String str) {
        return termJoiner.join(SemanticSearchServiceUtils.splitRemoveStopWords(str));
    }

    /**
     * Removes stop words from {@code str}, appends {@code ^d} to every remaining word, and joins the words with
     * single spaces. The boost {@code d} is formatted with at most five fraction digits.
     */
    String parseBoostQueryString(String str, double d) {
        String boost = DECIMAL_FORMATTER.get().format(d);
        List<String> boostedTerms = SemanticSearchServiceUtils.splitRemoveStopWords(str).stream()
                .map(term -> term + CARET_CHARACTER + boost)
                .collect(Collectors.toList());
        return termJoiner.join(boostedTerms);
    }
}
