/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.PriorityQueue;

/**
 * {@link BulkScorer} that is used for pure disjunctions and disjunctions that have low values of
 * {@link BooleanQuery.Builder#setMinimumNumberShouldMatch(int)} and dense clauses. This scorer
 * scores documents by windows of {@link #SIZE} (2048) docs: the clauses that have candidates in the
 * current window record their matches into a bit set (plus per-doc buckets of scores and match
 * counts when those are needed), which is then replayed to the collector as a {@link DocIdStream}.
 */
final class BooleanScorer extends BulkScorer {

  // log2 of the number of docs per window
  static final int SHIFT = 11;
  // Number of doc IDs per window
  static final int SIZE = 1 << SHIFT;
  // Extracts the index of a doc ID within its window
  static final int MASK = SIZE - 1;
  // Number of longs needed to hold one bit per doc of a window
  static final int SET_SIZE = 1 << (SHIFT - 6);
  static final int SET_MASK = SET_SIZE - 1;

  /** Per-doc accumulator for the current window: summed score and number of matching clauses. */
  static class Bucket {
    double score;
    int freq;
  }

  /** A clause's {@link BulkScorer} together with its cost and the last doc ID it reported. */
  private class BulkScorerAndDoc {
    final BulkScorer scorer;
    final long cost;
    // Value returned by the last call to scorer.score(...); -1 until the scorer is first used
    int next;

    BulkScorerAndDoc(BulkScorer scorer) {
      this.scorer = scorer;
      this.cost = scorer.cost();
      this.next = -1;
    }

    /**
     * Moves this scorer to {@code min} without collecting anything: the range {@code [min, min)} is
     * empty, so only {@link #next} gets updated.
     */
    void advance(int min) throws IOException {
      score(orCollector, null, min, min);
    }

    /** Scores docs in {@code [min, max)} into {@code collector} and records the returned doc ID. */
    void score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
      next = scorer.score(collector, acceptDocs, min, max);
    }
  }

  /**
   * Computes the cost of the disjunction as the sum of the costs of the {@code scorers.size() -
   * minShouldMatch + 1} cheapest scorers. A doc that matches at least {@code minShouldMatch} clauses
   * must match at least one clause out of any {@code scorers.size() - minShouldMatch + 1} of them,
   * so the cheapest such subset bounds the number of matches. See WANDScorer for a longer
   * explanation.
   */
  private static long cost(Collection<BulkScorer> scorers, int minShouldMatch) {
    // Top is the most costly scorer, so insertWithOverflow evicts the most costly ones and the
    // queue ends up holding the cheapest scorers.
    final PriorityQueue<BulkScorer> pq =
        new PriorityQueue<BulkScorer>(scorers.size() - minShouldMatch + 1) {
          @Override
          protected boolean lessThan(BulkScorer a, BulkScorer b) {
            return a.cost() > b.cost();
          }
        };
    for (BulkScorer scorer : scorers) {
      pq.insertWithOverflow(scorer);
    }
    long cost = 0;
    for (BulkScorer scorer = pq.pop(); scorer != null; scorer = pq.pop()) {
      cost += scorer.cost();
    }
    return cost;
  }

  /** Clauses ordered by their next doc ID; the top is the clause that is the furthest behind. */
  static final class HeadPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {

    public HeadPriorityQueue(int maxSize) {
      super(maxSize);
    }

    @Override
    protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
      return a.next < b.next;
    }
  }

  /** Clauses ordered by cost; the top is the cheapest clause. */
  static final class TailPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {

    public TailPriorityQueue(int maxSize) {
      super(maxSize);
    }

    @Override
    protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
      return a.cost < b.cost;
    }

    /**
     * Returns the {@code i}-th element in heap-array order (not sorted order).
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, size())}
     */
    public BulkScorerAndDoc get(int i) {
      Objects.checkIndex(i, size());
      // The heap array is 1-based
      return (BulkScorerAndDoc) getHeapArray()[1 + i];
    }
  }

  // One bucket per doc ID in the window, non-null if scores are needed or if frequencies need to be
  // counted
  final Bucket[] buckets;
  // This is basically an inlined FixedBitSet... seems to help with bound checks
  final long[] matching = new long[SET_SIZE];

  // Scratch array holding the clauses that take part in the current window
  final BulkScorerAndDoc[] leads;
  // Up to (numScorers - minShouldMatch + 1) clauses, ordered by next doc ID
  final HeadPriorityQueue head;
  // Up to (minShouldMatch - 1) clauses, ordered by cost
  final TailPriorityQueue tail;
  // Scorable exposed to the collector while replaying the bit set
  final Score score = new Score();
  final int minShouldMatch;
  final long cost;
  final boolean needsScores;

  /** Records matches of the sub-scorers into {@link #matching} and, if non-null, {@link #buckets}. */
  final class OrCollector implements LeafCollector {
    Scorable scorer;

    @Override
    public void setScorer(Scorable scorer) {
      this.scorer = scorer;
    }

    @Override
    public void collect(int doc) throws IOException {
      final int i = doc & MASK;
      final int idx = i >>> 6;
      // Java only uses the low 6 bits of a long shift distance, so this sets bit (i & 63)
      matching[idx] |= 1L << i;
      if (buckets != null) {
        final Bucket bucket = buckets[i];
        bucket.freq++;
        if (needsScores) {
          bucket.score += scorer.score();
        }
      }
    }
  }

  final OrCollector orCollector = new OrCollector();

  /** Replays the matches of the current window, which starts at doc ID {@link #base}. */
  final class DocIdStreamView extends DocIdStream {

    int base;

    /**
     * Feeds every doc of the window whose bit is set (and, when buckets are in use, whose freq
     * reaches minShouldMatch) to {@code consumer}, in increasing doc ID order. Buckets of visited
     * docs are reset so that the next window starts from zero.
     */
    @Override
    public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
      long[] matching = BooleanScorer.this.matching;
      Bucket[] buckets = BooleanScorer.this.buckets;
      int base = this.base;
      for (int idx = 0; idx < matching.length; idx++) {
        long bits = matching[idx];
        while (bits != 0L) {
          int ntz = Long.numberOfTrailingZeros(bits);
          if (buckets != null) {
            final int indexInWindow = (idx << 6) | ntz;
            final Bucket bucket = buckets[indexInWindow];
            if (bucket.freq >= minShouldMatch) {
              score.score = (float) bucket.score;
              consumer.accept(base | indexInWindow);
            }
            bucket.freq = 0;
            bucket.score = 0;
          } else {
            consumer.accept(base | (idx << 6) | ntz);
          }
          // Clear the lowest set bit
          bits ^= 1L << ntz;
        }
      }
    }

    /**
     * Counts the matches of the window. Bits can only be counted directly when no buckets are in
     * use; otherwise counting goes through {@link #forEach}, which both applies the minShouldMatch
     * filter and resets the buckets for the next window.
     */
    @Override
    public int count() throws IOException {
      if (buckets != null) {
        // With minShouldMatch > 1 we can't just count bits. And even with minShouldMatch == 1,
        // buckets (allocated because scores are needed) must be reset, otherwise stale scores would
        // leak into the next window. super.count() counts via forEach, which does both.
        return super.count();
      }
      // No buckets: minShouldMatch == 1, so every set bit is a match
      int count = 0;
      for (long l : matching) {
        count += Long.bitCount(l);
      }
      return count;
    }
  }

  private final DocIdStreamView docIdStreamView = new DocIdStreamView();

  /**
   * Creates a scorer over the given clauses.
   *
   * @param scorers the clause scorers; at least two
   * @param minShouldMatch minimum number of clauses a doc must match, in {@code 1..scorers.size()}
   * @param needsScores whether summed clause scores must be exposed to the collector
   * @throws IllegalArgumentException if {@code minShouldMatch} is out of range or there are fewer
   *     than two scorers
   */
  BooleanScorer(Collection<BulkScorer> scorers, int minShouldMatch, boolean needsScores) {
    if (minShouldMatch < 1 || minShouldMatch > scorers.size()) {
      throw new IllegalArgumentException(
          "minShouldMatch should be within 1..num_scorers. Got " + minShouldMatch);
    }
    if (scorers.size() <= 1) {
      throw new IllegalArgumentException(
          "This scorer can only be used with two scorers or more, got " + scorers.size());
    }
    if (needsScores || minShouldMatch > 1) {
      buckets = new Bucket[SIZE];
      for (int i = 0; i < buckets.length; i++) {
        buckets[i] = new Bucket();
      }
    } else {
      buckets = null;
    }
    this.leads = new BulkScorerAndDoc[scorers.size()];
    this.head = new HeadPriorityQueue(scorers.size() - minShouldMatch + 1);
    this.tail = new TailPriorityQueue(minShouldMatch - 1);
    this.minShouldMatch = minShouldMatch;
    this.needsScores = needsScores;
    // The tail keeps the (minShouldMatch - 1) most costly clauses; the cheaper ones it evicts go
    // to the head
    for (BulkScorer scorer : scorers) {
      final BulkScorerAndDoc evicted = tail.insertWithOverflow(new BulkScorerAndDoc(scorer));
      if (evicted != null) {
        head.add(evicted);
      }
    }
    this.cost = cost(scorers, minShouldMatch);
  }

  @Override
  public long cost() {
    return cost;
  }

  /**
   * Scores the first {@code numScorers} entries of {@code scorers} over {@code [min, max)} into the
   * bit set, replays the window (starting at {@code base}) to {@code collector}, then clears the bit
   * set. Buckets are reset by {@link DocIdStreamView} as the stream is consumed.
   */
  private void scoreWindowIntoBitSetAndReplay(
      LeafCollector collector,
      Bits acceptDocs,
      int base,
      int min,
      int max,
      BulkScorerAndDoc[] scorers,
      int numScorers)
      throws IOException {
    for (int i = 0; i < numScorers; ++i) {
      final BulkScorerAndDoc scorer = scorers[i];
      assert scorer.next < max;
      scorer.score(orCollector, acceptDocs, min, max);
    }

    docIdStreamView.base = base;
    collector.collect(docIdStreamView);

    Arrays.fill(matching, 0L);
  }

  /**
   * Advances clauses until the top of the head is at or beyond {@code min}, and returns that top.
   * When the cheapest tail clause is cheaper than the head's top, it is advanced instead and the
   * two are swapped between head and tail.
   */
  private BulkScorerAndDoc advance(int min) throws IOException {
    assert tail.size() == minShouldMatch - 1;
    final HeadPriorityQueue head = this.head;
    final TailPriorityQueue tail = this.tail;
    BulkScorerAndDoc headTop = head.top();
    BulkScorerAndDoc tailTop = tail.top();
    while (headTop.next < min) {
      if (tailTop == null || headTop.cost <= tailTop.cost) {
        headTop.advance(min);
        headTop = head.updateTop();
      } else {
        // swap the top of head and tail
        final BulkScorerAndDoc previousHeadTop = headTop;
        tailTop.advance(min);
        headTop = head.updateTop(tailTop);
        tailTop = tail.updateTop(previousHeadTop);
      }
    }
    return headTop;
  }

  /**
   * General case of window scoring. {@code leads[0..maxFreq)} holds the head clauses that have
   * candidates in the window. Tail clauses are pulled in while a match of minShouldMatch clauses is
   * still possible. If enough clauses have candidates, all of them (including the rest of the tail)
   * are scored through the bit set. Finally the leads are redistributed between head and tail.
   */
  private void scoreWindowMultipleScorers(
      LeafCollector collector,
      Bits acceptDocs,
      int windowBase,
      int windowMin,
      int windowMax,
      int maxFreq)
      throws IOException {
    while (maxFreq < minShouldMatch && maxFreq + tail.size() >= minShouldMatch) {
      // a match is still possible
      final BulkScorerAndDoc candidate = tail.pop();
      candidate.advance(windowMin);
      if (candidate.next < windowMax) {
        leads[maxFreq++] = candidate;
      } else {
        head.add(candidate);
      }
    }

    if (maxFreq >= minShouldMatch) {
      // There might be matches in other scorers from the tail too
      for (int i = 0; i < tail.size(); ++i) {
        leads[maxFreq++] = tail.get(i);
      }
      tail.clear();

      scoreWindowIntoBitSetAndReplay(
          collector, acceptDocs, windowBase, windowMin, windowMax, leads, maxFreq);
    }

    // Push back scorers into head and tail
    for (int i = 0; i < maxFreq; ++i) {
      final BulkScorerAndDoc evicted = head.insertWithOverflow(leads[i]);
      if (evicted != null) {
        tail.add(evicted);
      }
    }
  }

  /**
   * Fast path for minShouldMatch == 1 when a single clause has candidates in the window. That clause
   * scores straight into {@code collector}, up to the start of the window of the next head clause
   * (capped by {@code max}, but never stopping before {@code windowMax}).
   */
  private void scoreWindowSingleScorer(
      BulkScorerAndDoc bulkScorer,
      LeafCollector collector,
      Bits acceptDocs,
      int windowMin,
      int windowMax,
      int max)
      throws IOException {
    assert tail.size() == 0;
    final int nextWindowBase = head.top().next & ~MASK;
    final int end = Math.max(windowMax, Math.min(max, nextWindowBase));

    bulkScorer.score(collector, acceptDocs, windowMin, end);

    // reset the scorer that should be used for the general case
    collector.setScorer(score);
  }

  /**
   * Scores the window that contains {@code top.next}, restricted to {@code [min, max)}, and returns
   * the new top of the head.
   */
  private BulkScorerAndDoc scoreWindow(
      BulkScorerAndDoc top, LeafCollector collector, Bits acceptDocs, int min, int max)
      throws IOException {
    final int windowBase = top.next & ~MASK; // find the window that the next match belongs to
    final int windowMin = Math.max(min, windowBase);
    final int windowMax = Math.min(max, windowBase + SIZE);

    // Fill 'leads' with all scorers from 'head' that are in the right window
    leads[0] = head.pop();
    int maxFreq = 1;
    while (head.size() > 0 && head.top().next < windowMax) {
      leads[maxFreq++] = head.pop();
    }

    if (minShouldMatch == 1 && maxFreq == 1) {
      // special case: only one scorer can match in the current window,
      // we can collect directly
      final BulkScorerAndDoc bulkScorer = leads[0];
      scoreWindowSingleScorer(bulkScorer, collector, acceptDocs, windowMin, windowMax, max);
      return head.add(bulkScorer);
    } else {
      // general case, collect through a bit set first and then replay
      scoreWindowMultipleScorers(collector, acceptDocs, windowBase, windowMin, windowMax, maxFreq);
      return head.top();
    }
  }

  /**
   * Collects matches in {@code [min, max)} window by window, and returns the next doc ID of the
   * head's top clause once it is at or beyond {@code max}.
   */
  @Override
  public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
    collector.setScorer(score);

    BulkScorerAndDoc top = advance(min);
    while (top.next < max) {
      top = scoreWindow(top, collector, acceptDocs, min, max);
    }

    return top.next;
  }
}
