/*
 * Decompiled with CFR 0.152.
 */
package com.databricks.internal.apache.arrow.vector;

import com.databricks.internal.apache.arrow.memory.ArrowBuf;
import com.databricks.internal.apache.arrow.memory.BoundsChecking;
import com.databricks.internal.apache.arrow.memory.BufferAllocator;
import com.databricks.internal.apache.arrow.memory.util.LargeMemoryUtil;
import com.databricks.internal.apache.arrow.memory.util.MemoryUtil;
import com.databricks.internal.apache.arrow.vector.ipc.message.ArrowFieldNode;
import com.databricks.internal.apache.arrow.vector.util.DataSizeRoundingUtil;

import java.util.Objects;

/**
 * Static helpers for validity bitmaps stored in {@link ArrowBuf}s.
 *
 * <p>Slot {@code i} maps to byte {@code i >> 3} and bit {@code i & 7} of that byte, so bits are
 * numbered from the least significant bit within each byte. A set bit marks a non-null slot and a
 * cleared bit marks a null slot ({@link #getNullCount} counts cleared bits as nulls, and
 * {@link #loadValidityBuffer} writes all-ones for a node with no nulls).
 */
public final class BitVectorHelper {
    private BitVectorHelper() {
        // Utility class; not instantiable.
    }

    /** Returns the index of the byte that holds bit {@code absoluteBitIndex}. */
    public static long byteIndex(long absoluteBitIndex) {
        return absoluteBitIndex >> 3;
    }

    /** Returns the position (0-7) of bit {@code absoluteBitIndex} within its byte. */
    public static int bitIndex(long absoluteBitIndex) {
        return LargeMemoryUtil.checkedCastToInt(absoluteBitIndex & 7L);
    }

    /** Returns the index of the byte that holds bit {@code absoluteBitIndex}. */
    public static int byteIndex(int absoluteBitIndex) {
        return absoluteBitIndex >> 3;
    }

    /** Returns the position (0-7) of bit {@code absoluteBitIndex} within its byte. */
    public static int bitIndex(int absoluteBitIndex) {
        return absoluteBitIndex & 7;
    }

    /**
     * Sets the bit for slot {@code index} to 1 (non-null), leaving the other bits of its byte
     * unchanged.
     *
     * @param validityBuffer bitmap to modify
     * @param index slot whose bit is set
     */
    public static void setBit(ArrowBuf validityBuffer, long index) {
        long byteIndex = byteIndex(index);
        int bitMask = 1 << bitIndex(index);
        validityBuffer.setByte(byteIndex, validityBuffer.getByte(byteIndex) | bitMask);
    }

    /**
     * Clears the bit for slot {@code index} (marks it null), leaving the other bits of its byte
     * unchanged.
     *
     * @param validityBuffer bitmap to modify
     * @param index slot whose bit is cleared
     */
    public static void unsetBit(ArrowBuf validityBuffer, int index) {
        int byteIndex = byteIndex(index);
        int bitMask = 1 << bitIndex(index);
        validityBuffer.setByte((long) byteIndex, validityBuffer.getByte(byteIndex) & ~bitMask);
    }

    /**
     * Sets the bit for slot {@code index} to 1 when {@code value} is non-zero, or clears it when
     * {@code value} is zero. Other bits of the byte are unchanged.
     *
     * @param validityBuffer bitmap to modify
     * @param index slot to update
     * @param value non-zero for non-null, zero for null
     */
    public static void setValidityBit(ArrowBuf validityBuffer, int index, int value) {
        int byteIndex = byteIndex(index);
        int bitMask = 1 << bitIndex(index);
        int currentByte = validityBuffer.getByte(byteIndex);
        int newByte = value != 0 ? (currentByte | bitMask) : (currentByte & ~bitMask);
        validityBuffer.setByte((long) byteIndex, newByte);
    }

    /**
     * Sets one validity bit, allocating the bitmap first when none exists yet.
     *
     * @param validityBuffer bitmap to update, or {@code null} to allocate a new one
     * @param allocator allocator used only when {@code validityBuffer} is {@code null}
     * @param valueCount number of slots the bitmap covers; sizes a new allocation and the final
     *     writer index
     * @param index slot to update
     * @param value non-zero for non-null, zero for null
     * @return the updated bitmap (the newly allocated one when {@code validityBuffer} was null)
     */
    public static ArrowBuf setValidityBit(ArrowBuf validityBuffer, BufferAllocator allocator, int valueCount, int index, int value) {
        if (validityBuffer == null) {
            validityBuffer = allocator.buffer(getValidityBufferSize(valueCount));
            // Memory from the allocator is not assumed to be zeroed (loadValidityBuffer clears its
            // fresh buffer for the same reason). Clear it so slots that are never written read as
            // null instead of as leftover garbage.
            validityBuffer.setZero(0L, validityBuffer.capacity());
        }
        setValidityBit(validityBuffer, index, value);
        if (index == valueCount - 1) {
            // Writing the last slot advances the writer index over the whole bitmap.
            validityBuffer.writerIndex(getValidityBufferSize(valueCount));
        }
        return validityBuffer;
    }

    /**
     * Returns the bit for slot {@code index}.
     *
     * @return 1 if the bit is set, 0 otherwise
     */
    public static int get(ArrowBuf buffer, int index) {
        byte containingByte = buffer.getByte(byteIndex(index));
        return (containingByte >> bitIndex(index)) & 1;
    }

    /** Returns the number of bytes needed to hold {@code valueCount} bits (rounded up to whole bytes). */
    public static int getValidityBufferSize(int valueCount) {
        return DataSizeRoundingUtil.divideBy8Ceil(valueCount);
    }

    /**
     * Counts the cleared (null) bits among the first {@code valueCount} bits of the bitmap.
     *
     * @param validityBuffer bitmap to scan
     * @param valueCount number of slots to consider
     * @return the number of null slots; 0 when {@code valueCount} is 0
     */
    public static int getNullCount(ArrowBuf validityBuffer, int valueCount) {
        if (valueCount == 0) {
            return 0;
        }
        int count = 0;
        int sizeInBytes = getValidityBufferSize(valueCount);
        // Number of used bits in the last byte; 0 means the last byte is fully used.
        int remainder = valueCount % 8;
        int fullBytesCount = remainder == 0 ? sizeInBytes : sizeInBytes - 1;
        int index = 0;
        // Count set bits over the fully used bytes: 8 bytes at a time, then 4, then 1.
        while (index + 8 <= fullBytesCount) {
            count += Long.bitCount(validityBuffer.getLong(index));
            index += 8;
        }
        if (index + 4 <= fullBytesCount) {
            count += Integer.bitCount(validityBuffer.getInt(index));
            index += 4;
        }
        while (index < fullBytesCount) {
            count += Integer.bitCount(validityBuffer.getByte(index) & 0xFF);
            ++index;
        }
        if (remainder != 0) {
            byte lastByte = validityBuffer.getByte(sizeInBytes - 1);
            // Force the unused high bits to 1 so they are counted as set and cancel out of the
            // "total bits - set bits" subtraction below.
            byte paddingMask = (byte) (0xFF << remainder);
            lastByte = (byte) (lastByte | paddingMask);
            count += Integer.bitCount(lastByte & 0xFF);
        }
        return 8 * sizeInBytes - count;
    }

    /**
     * Checks whether the first {@code valueCount} bits of the bitmap all have the same value.
     * Padding bits beyond {@code valueCount} in the last byte are ignored.
     *
     * @param validityBuffer bitmap to inspect; the first {@code getValidityBufferSize(valueCount)}
     *     bytes are bounds-checked once up front and then read directly from memory
     * @param valueCount number of bits to check
     * @param checkOneBits {@code true} to require all bits set, {@code false} to require all cleared
     * @return {@code true} if every inspected bit matches (always {@code true} for 0 values)
     */
    public static boolean checkAllBitsEqualTo(ArrowBuf validityBuffer, int valueCount, boolean checkOneBits) {
        if (valueCount == 0) {
            return true;
        }
        int sizeInBytes = getValidityBufferSize(valueCount);
        // One bounds check covers all of the unchecked UNSAFE reads below.
        validityBuffer.checkBytes(0L, sizeInBytes);
        int remainder = valueCount % 8;
        int fullBytesCount = remainder == 0 ? sizeInBytes : sizeInBytes - 1;
        // -1 is all ones in every width, so it also serves as the long/byte comparand.
        int intToCompare = checkOneBits ? -1 : 0;
        long baseAddress = validityBuffer.memoryAddress();
        int index = 0;
        while (index + 8 <= fullBytesCount) {
            if (MemoryUtil.UNSAFE.getLong(baseAddress + (long) index) != (long) intToCompare) {
                return false;
            }
            index += 8;
        }
        if (index + 4 <= fullBytesCount) {
            if (MemoryUtil.UNSAFE.getInt(baseAddress + (long) index) != intToCompare) {
                return false;
            }
            index += 4;
        }
        while (index < fullBytesCount) {
            if (MemoryUtil.UNSAFE.getByte(baseAddress + (long) index) != (byte) intToCompare) {
                return false;
            }
            ++index;
        }
        if (remainder != 0) {
            byte lastByte = MemoryUtil.UNSAFE.getByte(baseAddress + (long) sizeInBytes - 1L);
            // Only the low `remainder` bits of the last byte belong to the vector.
            byte usedMask = (byte) ((1 << remainder) - 1);
            lastByte = (byte) (lastByte & usedMask);
            boolean lastByteMatches = checkOneBits ? lastByte == usedMask : lastByte == 0;
            if (!lastByteMatches) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns byte {@code index} of {@code data} logically shifted right by {@code offset}
     * (zero-filled from the top), narrowed to a byte.
     */
    public static byte getBitsFromCurrentByte(ArrowBuf data, int index, int offset) {
        return (byte) ((data.getByte(index) & 0xFF) >>> offset);
    }

    /**
     * Returns byte {@code index} of {@code data} shifted left by {@code 8 - offset}, narrowed to a
     * byte. Presumably paired with {@link #getBitsFromCurrentByte} to assemble a byte that straddles
     * two source bytes; confirm against callers.
     */
    public static byte getBitsFromNextByte(ArrowBuf data, int index, int offset) {
        return (byte) (data.getByte(index) << (8 - offset));
    }

    /**
     * Produces the validity buffer for a vector being loaded from {@code fieldNode}.
     *
     * <p>When {@code sourceValidityBuffer} is absent (null or zero capacity) and the node reports
     * either no nulls or only nulls, a new bitmap is allocated from {@code allocator`} and filled with
     * ones or zeros for the first {@code fieldNode.getLength()} slots. In every other case the
     * result is {@code sourceValidityBuffer.getReferenceManager().retain(sourceValidityBuffer,
     * allocator)}.
     *
     * @throws NullPointerException if the source buffer is null but the node reports a mix of null
     *     and non-null slots, so the bitmap cannot be synthesized
     */
    public static ArrowBuf loadValidityBuffer(ArrowFieldNode fieldNode, ArrowBuf sourceValidityBuffer, BufferAllocator allocator) {
        int valueCount = fieldNode.getLength();
        long nullCount = fieldNode.getNullCount();
        boolean isValidityBufferNull = sourceValidityBuffer == null || sourceValidityBuffer.capacity() == 0L;
        if (isValidityBufferNull && (nullCount == 0 || nullCount == valueCount)) {
            ArrowBuf newBuffer = allocator.buffer(getValidityBufferSize(valueCount));
            newBuffer.setZero(0L, newBuffer.capacity());
            if (nullCount != 0) {
                // All slots are null: the zeroed buffer is already correct.
                return newBuffer;
            }
            // No slots are null: set every bit belonging to the vector.
            int fullBytesCount = valueCount / 8;
            newBuffer.setOne(0, fullBytesCount);
            int remainder = valueCount % 8;
            if (remainder > 0) {
                // Set only the low `remainder` bits of the trailing partial byte.
                byte bitMask = (byte) ((1 << remainder) - 1);
                newBuffer.setByte((long) fullBytesCount, bitMask);
            }
            return newBuffer;
        }
        // Mixed nulls (or a real source bitmap): the source buffer is required.
        Objects.requireNonNull(sourceValidityBuffer,
                () -> "validity buffer is required for a field node with " + nullCount
                        + " nulls out of " + valueCount + " values");
        return sourceValidityBuffer.getReferenceManager().retain(sourceValidityBuffer, allocator);
    }

    /** ORs {@code bitMask} into byte {@code byteIndex} of {@code data}. */
    static void setBitMaskedByte(ArrowBuf data, int byteIndex, byte bitMask) {
        byte updated = (byte) (data.getByte(byteIndex) | bitMask);
        data.setByte((long) byteIndex, updated);
    }

    /**
     * Writes the first {@code numBits1} bits of {@code input1} followed by the first
     * {@code numBits2} bits of {@code input2} into {@code output}.
     *
     * <p>{@code output} may be the same buffer as {@code input1`}; the first part is then not copied.
     * Padding bits past {@code numBits1 + numBits2} in the last output byte are not explicitly
     * cleared and may hold leftover bits from {@code input2}'s final byte.
     *
     * @param input1 first bitmap
     * @param numBits1 number of bits taken from {@code input1}
     * @param input2 second bitmap; only read when {@code numBits2 > 0} or when the split falls on a
     *     byte boundary
     * @param numBits2 number of bits taken from {@code input2}
     * @param output destination bitmap; must hold at least {@code ceil((numBits1 + numBits2) / 8)}
     *     bytes
     */
    public static void concatBits(ArrowBuf input1, int numBits1, ArrowBuf input2, int numBits2, ArrowBuf output) {
        int numBytes1 = DataSizeRoundingUtil.divideBy8Ceil(numBits1);
        int numBytes2 = DataSizeRoundingUtil.divideBy8Ceil(numBits2);
        int numBytesOut = DataSizeRoundingUtil.divideBy8Ceil(numBits1 + numBits2);
        if (BoundsChecking.BOUNDS_CHECKING_ENABLED) {
            output.checkBytes(0L, numBytesOut);
            // The UNSAFE copies below bypass ArrowBuf's own checks, so validate the source
            // ranges too.
            if (input1 != output) {
                input1.checkBytes(0L, numBytes1);
            }
            if (numBytes2 > 0) {
                input2.checkBytes(0L, numBytes2);
            }
        }
        if (input1 != output) {
            MemoryUtil.UNSAFE.copyMemory(input1.memoryAddress(), output.memoryAddress(), numBytes1);
        }
        int usedBitsInLastByte = bitIndex(numBits1);
        if (usedBitsInLastByte == 0) {
            // The first part ends on a byte boundary, so the second part can be copied bytewise.
            MemoryUtil.UNSAFE.copyMemory(input2.memoryAddress(), output.memoryAddress() + (long) numBytes1, numBytes2);
            return;
        }
        // Free high bits in the last byte of the first part.
        int numBitsToFill = 8 - usedBitsInLastByte;
        // Keeps only the valid low bits of that last byte, discarding any padding garbage.
        int mask = (1 << usedBitsInLastByte) - 1;
        int numFullBytes = numBits2 / 8;
        int prevByte = output.getByte(numBytes1 - 1) & mask;
        for (int i = 0; i < numFullBytes; ++i) {
            int curByte = input2.getByte(i) & 0xFF;
            // The low bits of curByte complete the pending output byte...
            int byteToFill = (curByte << usedBitsInLastByte) & 0xFF;
            output.setByte((long) (numBytes1 + i - 1), byteToFill | prevByte);
            // ...and its high bits start the next one.
            prevByte = curByte >>> numBitsToFill;
        }
        int lastOutputByte = prevByte;
        int numTrailingBits = bitIndex(numBits2);
        if (numTrailingBits == 0) {
            output.setByte((long) (numBytes1 + numFullBytes - 1), lastOutputByte);
            return;
        }
        int remByte = input2.getByte(numBytes2 - 1) & 0xFF;
        lastOutputByte |= remByte << usedBitsInLastByte;
        output.setByte((long) (numBytes1 + numFullBytes - 1), lastOutputByte);
        if (numTrailingBits > numBitsToFill) {
            // The trailing bits overflow into one more output byte.
            output.setByte((long) (numBytes1 + numFullBytes), remByte >>> numBitsToFill);
        }
    }
}

