/*
 * Decompiled with CFR 0.152.
 */
package org.tinyradius.packet;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.tinyradius.attribute.AttributeType;
import org.tinyradius.attribute.Attributes;
import org.tinyradius.attribute.RadiusAttribute;
import org.tinyradius.attribute.VendorSpecificAttribute;
import org.tinyradius.dictionary.Dictionary;
import org.tinyradius.packet.PacketType;
import org.tinyradius.packet.RadiusPackets;
import org.tinyradius.util.RadiusPacketException;

/**
 * A RADIUS packet: a packet type code, an identifier, an optional 16-octet authenticator
 * and an ordered list of attributes, bound to the {@link Dictionary} used to resolve
 * attribute names and to create attributes.
 *
 * <p>Vendor sub-attributes are stored inside {@link VendorSpecificAttribute} containers
 * (attribute type {@value #VENDOR_SPECIFIC_TYPE}) in the top-level attribute list.
 *
 * <p>The attribute list is mutable; no synchronization is performed.
 */
public class RadiusPacket {
    /** Length in octets of the fixed header: type, identifier, 2-octet length, 16-octet authenticator. */
    public static final int HEADER_LENGTH = 20;
    /** Attribute type code of the Vendor-Specific attribute. */
    private static final int VENDOR_SPECIFIC_TYPE = 26;
    /** Length in octets of a packet authenticator. */
    private static final int AUTHENTICATOR_LENGTH = 16;

    private final int type;
    private final int identifier;
    private final List<RadiusAttribute> attributes;
    private final byte[] authenticator;
    private final Dictionary dictionary;

    /**
     * Creates a packet with no authenticator and no attributes.
     */
    public RadiusPacket(Dictionary dictionary, int type, int identifier) {
        this(dictionary, type, identifier, null, new ArrayList<RadiusAttribute>());
    }

    /**
     * Creates a packet with the given authenticator and no attributes.
     */
    public RadiusPacket(Dictionary dictionary, int type, int identifier, byte[] authenticator) {
        this(dictionary, type, identifier, authenticator, new ArrayList<RadiusAttribute>());
    }

    /**
     * Creates a packet with no authenticator and the given attributes.
     */
    public RadiusPacket(Dictionary dictionary, int type, int identifier, List<RadiusAttribute> attributes) {
        this(dictionary, type, identifier, null, attributes);
    }

    /**
     * Creates a packet.
     *
     * @param dictionary    dictionary used to create/resolve attributes, not null
     * @param type          packet type code, 1..255
     * @param identifier    packet identifier, 0..255
     * @param authenticator 16-octet authenticator, or null if not yet known; the array is copied
     * @param attributes    initial attributes, not null; the list is copied
     * @throws IllegalArgumentException if type, identifier or authenticator length is out of range
     * @throws NullPointerException     if dictionary or attributes is null
     */
    public RadiusPacket(Dictionary dictionary, int type, int identifier, byte[] authenticator, List<RadiusAttribute> attributes) {
        if (type < 1 || type > 255) {
            throw new IllegalArgumentException("packet type out of bounds: " + type);
        }
        if (identifier < 0 || identifier > 255) {
            throw new IllegalArgumentException("packet identifier out of bounds: " + identifier);
        }
        if (authenticator != null && authenticator.length != AUTHENTICATOR_LENGTH) {
            throw new IllegalArgumentException("authenticator must be 16 octets, actual: " + authenticator.length);
        }
        this.type = type;
        this.identifier = identifier;
        // defensive copy: getAuthenticator() already clones on the way out
        this.authenticator = authenticator == null ? null : authenticator.clone();
        this.attributes = new ArrayList<RadiusAttribute>(Objects.requireNonNull(attributes, "attributes is null"));
        this.dictionary = Objects.requireNonNull(dictionary, "dictionary is null");
    }

    /** Returns the packet identifier (0..255). */
    public int getIdentifier() {
        return this.identifier;
    }

    /** Returns the packet type code (1..255). */
    public int getType() {
        return this.type;
    }

    /**
     * Adds an attribute to this packet.
     *
     * <p>A {@link VendorSpecificAttribute} container is added as-is. A non-vendor attribute
     * (vendor id -1) is re-created against this packet's dictionary. Any other vendor
     * sub-attribute is wrapped in a new Vendor-Specific container for its vendor.
     *
     * @param attribute attribute to add, not null
     * @throws NullPointerException if attribute is null
     */
    public void addAttribute(RadiusAttribute attribute) {
        Objects.requireNonNull(attribute, "Attribute is null");
        if (attribute instanceof VendorSpecificAttribute) {
            // already a VSA container; wrapping it again would produce a nested VSA
            this.attributes.add(attribute);
        } else if (attribute.getVendorId() == -1) {
            this.attributes.add(Attributes.createAttribute(this.dictionary, attribute.getVendorId(), attribute.getType(), attribute.getValue()));
        } else {
            VendorSpecificAttribute vsa = new VendorSpecificAttribute(this.dictionary, attribute.getVendorId());
            vsa.addSubAttribute(attribute);
            this.attributes.add(vsa);
        }
    }

    /**
     * Adds an attribute looked up by name in the dictionary, with a string value.
     *
     * @param typeName attribute name known to the dictionary, non-empty
     * @param value    attribute value as a string, non-empty
     * @throws IllegalArgumentException if a parameter is null/empty or the name is unknown
     */
    public void addAttribute(String typeName, String value) {
        if (typeName == null || typeName.isEmpty()) {
            throw new IllegalArgumentException("type name is empty");
        }
        if (value == null || value.isEmpty()) {
            throw new IllegalArgumentException("value is empty");
        }
        AttributeType type = this.dictionary.getAttributeTypeByName(typeName);
        if (type == null) {
            throw new IllegalArgumentException("unknown attribute type '" + typeName + "'");
        }
        RadiusAttribute attribute = Attributes.createAttribute(this.getDictionary(), type.getVendorId(), type.getTypeCode(), value);
        this.addAttribute(attribute);
    }

    /**
     * Removes an attribute.
     *
     * <p>Top-level attributes (vendor id -1 or Vendor-Specific containers) are removed from
     * the packet directly. A vendor sub-attribute is removed from every Vendor-Specific
     * container of its vendor, and containers left empty are removed as well.
     */
    public void removeAttribute(RadiusAttribute attribute) {
        if (attribute.getVendorId() == -1 || attribute.getType() == VENDOR_SPECIFIC_TYPE) {
            this.attributes.remove(attribute);
        } else {
            List<VendorSpecificAttribute> vsas = this.getVendorAttributes(attribute.getVendorId());
            for (VendorSpecificAttribute vsa : vsas) {
                vsa.removeSubAttribute(attribute);
                if (vsa.getSubAttributes().isEmpty()) {
                    this.removeAttribute(vsa);
                }
            }
        }
    }

    /**
     * Removes all top-level attributes with the given type code.
     */
    public void removeAttributes(int type) {
        this.attributes.removeIf(a -> a.getType() == type);
    }

    /**
     * Removes the last top-level attribute with the given type code, if any.
     */
    public void removeLastAttribute(int type) {
        List<RadiusAttribute> attrs = this.getAttributes(type);
        if (attrs == null || attrs.isEmpty()) {
            return;
        }
        this.removeAttribute(attrs.get(attrs.size() - 1));
    }

    /**
     * Removes all attributes with the given vendor id and type code.
     *
     * <p>For vendor id -1 this delegates to {@link #removeAttributes(int)}. Otherwise matching
     * sub-attributes are removed from that vendor's Vendor-Specific containers, and containers
     * left empty are removed.
     */
    public void removeAttributes(int vendorId, int typeCode) {
        if (vendorId == -1) {
            this.removeAttributes(typeCode);
            return;
        }
        List<VendorSpecificAttribute> vsas = this.getVendorAttributes(vendorId);
        for (VendorSpecificAttribute vsa : vsas) {
            // NOTE(review): assumes getSubAttributes() returns the container's live list
            List<RadiusAttribute> sas = vsa.getSubAttributes();
            sas.removeIf(attr -> attr.getType() == typeCode && attr.getVendorId() == vendorId);
            if (sas.isEmpty()) {
                this.removeAttribute(vsa);
            }
        }
    }

    /**
     * Returns a new list of the top-level attributes with the given type code.
     */
    public List<RadiusAttribute> getAttributes(int type) {
        return this.attributes.stream()
                .filter(a -> a.getType() == type)
                .collect(Collectors.toList());
    }

    /**
     * Returns a new list of attributes with the given vendor id and type code; for vendor id -1
     * this is {@link #getAttributes(int)}, otherwise sub-attributes of that vendor's containers.
     */
    public List<RadiusAttribute> getAttributes(int vendorId, int attributeType) {
        if (vendorId == -1) {
            return this.getAttributes(attributeType);
        }
        return this.getVendorAttributes(vendorId).stream()
                .map(VendorSpecificAttribute::getSubAttributes)
                .flatMap(Collection::stream)
                .filter(sa -> sa.getType() == attributeType && sa.getVendorId() == vendorId)
                .collect(Collectors.toList());
    }

    /**
     * Returns this packet's attribute list itself (not a copy); changes to it affect the packet.
     */
    public List<RadiusAttribute> getAttributes() {
        return this.attributes;
    }

    /**
     * Returns the single top-level attribute with the given type code.
     *
     * @return the attribute, or null if there is none
     * @throws IllegalStateException if more than one attribute has that type
     */
    public RadiusAttribute getAttribute(int type) {
        return single(this.getAttributes(type), type);
    }

    /**
     * Returns the single attribute with the given vendor id and type code.
     *
     * @return the attribute, or null if there is none
     * @throws IllegalStateException if more than one attribute matches
     */
    public RadiusAttribute getAttribute(int vendorId, int type) {
        if (vendorId == -1) {
            return this.getAttribute(type);
        }
        return single(this.getAttributes(vendorId, type), type);
    }

    /** Returns the only element of attrs, null if empty, or throws if there are several. */
    private static RadiusAttribute single(List<RadiusAttribute> attrs, int type) {
        if (attrs.size() > 1) {
            throw new IllegalStateException("multiple attributes of requested type " + type);
        }
        return attrs.isEmpty() ? null : attrs.get(0);
    }

    /**
     * Returns the single attribute with the given dictionary name.
     *
     * @return the attribute, or null if there is none
     * @throws IllegalArgumentException if the name is null/empty or unknown to the dictionary
     * @throws IllegalStateException    if more than one attribute matches
     */
    public RadiusAttribute getAttribute(String type) {
        if (type == null || type.isEmpty()) {
            throw new IllegalArgumentException("type name is empty");
        }
        AttributeType t = this.dictionary.getAttributeTypeByName(type);
        if (t == null) {
            throw new IllegalArgumentException("unknown attribute type name '" + type + "'");
        }
        return this.getAttribute(t.getVendorId(), t.getTypeCode());
    }

    /**
     * Returns the string value of the single attribute with the given name, or null if absent.
     */
    public String getAttributeValue(String type) {
        RadiusAttribute attr = this.getAttribute(type);
        return attr == null ? null : attr.getValueString();
    }

    /**
     * Returns the Vendor-Specific containers in this packet belonging to the given vendor.
     */
    public List<VendorSpecificAttribute> getVendorAttributes(int vendorId) {
        return this.getAttributes(VENDOR_SPECIFIC_TYPE).stream()
                .filter(VendorSpecificAttribute.class::isInstance)
                .map(VendorSpecificAttribute.class::cast)
                .filter(a -> a.getVendorId() == vendorId)
                .collect(Collectors.toList());
    }

    /**
     * Returns a copy of this packet whose authenticator is computed as in
     * {@link #encodeResponse} with a request authenticator of 16 zero octets.
     */
    public RadiusPacket encodeRequest(String sharedSecret) {
        return this.encodeResponse(sharedSecret, new byte[AUTHENTICATOR_LENGTH]);
    }

    /**
     * Returns a copy of this packet carrying the authenticator computed by
     * {@link #createHashedAuthenticator} from the shared secret and request authenticator.
     */
    public RadiusPacket encodeResponse(String sharedSecret, byte[] requestAuthenticator) {
        byte[] authenticator = this.createHashedAuthenticator(sharedSecret, requestAuthenticator);
        return RadiusPackets.create(this.dictionary, this.type, this.identifier, authenticator, this.attributes);
    }

    /** Returns a copy of the authenticator, or null if the packet has none. */
    public byte[] getAuthenticator() {
        return this.authenticator == null ? null : this.authenticator.clone();
    }

    /** Returns the dictionary this packet is bound to. */
    public Dictionary getDictionary() {
        return this.dictionary;
    }

    /**
     * Computes MD5(type + identifier + length + requestAuthenticator + attributes + secret),
     * the field order of the RFC 2865 Response Authenticator.
     *
     * @throws NullPointerException     if requestAuthenticator is null
     * @throws IllegalArgumentException if sharedSecret is null or empty
     */
    protected byte[] createHashedAuthenticator(String sharedSecret, byte[] requestAuthenticator) {
        Objects.requireNonNull(requestAuthenticator, "Authenticator cannot be null");
        if (sharedSecret == null || sharedSecret.isEmpty()) {
            throw new IllegalArgumentException("Shared secret cannot be null/empty");
        }
        byte[] attributes = this.getAttributeBytes();
        int packetLength = HEADER_LENGTH + attributes.length;
        MessageDigest md5 = RadiusPacket.getMd5Digest();
        md5.update((byte) this.getType());
        md5.update((byte) this.getIdentifier());
        // 2-octet big-endian packet length
        md5.update((byte) (packetLength >> 8));
        md5.update((byte) (packetLength & 0xFF));
        md5.update(requestAuthenticator);
        md5.update(attributes);
        return md5.digest(sharedSecret.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Checks this packet's authenticator against the one computed from the shared secret and
     * request authenticator.
     *
     * @throws RadiusPacketException if the packet has no authenticator or it does not match
     */
    public void verify(String sharedSecret, byte[] requestAuthenticator) throws RadiusPacketException {
        byte[] expectedAuth = this.createHashedAuthenticator(sharedSecret, requestAuthenticator);
        byte[] receivedAuth = this.getAuthenticator();
        if (receivedAuth == null) {
            throw new RadiusPacketException("Authenticator check failed (packet has no authenticator)");
        }
        // constant-time comparison so timing does not reveal how many leading bytes matched
        if (receivedAuth.length != AUTHENTICATOR_LENGTH || !MessageDigest.isEqual(expectedAuth, receivedAuth)) {
            throw new RadiusPacketException("Authenticator check failed (bad authenticator or shared secret)");
        }
    }

    /**
     * Returns a new MD5 digest instance.
     *
     * @throws IllegalStateException if the platform does not provide MD5
     */
    static MessageDigest getMd5Digest() {
        try {
            return MessageDigest.getInstance("MD5");
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 message digest not available", e);
        }
    }

    /**
     * Returns the concatenated wire encoding of all top-level attributes, in list order.
     */
    protected byte[] getAttributeBytes() {
        ByteBuf buffer = Unpooled.buffer();
        for (RadiusAttribute attribute : this.attributes) {
            buffer.writeBytes(attribute.toByteArray());
        }
        // read exactly the written bytes instead of relying on a backing array's length
        byte[] bytes = new byte[buffer.readableBytes()];
        buffer.readBytes(bytes);
        return bytes;
    }

    /**
     * Returns a map merged from each attribute's {@code toAttributeMap()}; on duplicate keys
     * later attributes overwrite earlier ones.
     */
    public Map<String, String> getAttributeMap() {
        HashMap<String, String> map = new HashMap<String, String>();
        this.attributes.forEach(a -> map.putAll(a.toAttributeMap()));
        return map;
    }

    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        s.append(PacketType.getPacketTypeName(this.getType()));
        s.append(", ID ");
        s.append(this.identifier);
        for (RadiusAttribute attr : this.attributes) {
            s.append("\n");
            s.append(attr.toString());
        }
        return s.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof RadiusPacket)) {
            return false;
        }
        RadiusPacket that = (RadiusPacket) o;
        return this.type == that.type
                && this.identifier == that.identifier
                && Objects.equals(this.attributes, that.attributes)
                && Arrays.equals(this.authenticator, that.authenticator)
                && Objects.equals(this.dictionary, that.dictionary);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(this.type, this.identifier, this.attributes, this.dictionary);
        result = 31 * result + Arrays.hashCode(this.authenticator);
        return result;
    }

    /**
     * Returns a new packet, built via {@code RadiusPackets.create}, with the same dictionary,
     * type, identifier, authenticator and a copy of the attribute list.
     */
    public RadiusPacket copy() {
        return RadiusPackets.create(this.getDictionary(), this.getType(), this.getIdentifier(), this.getAuthenticator(), new ArrayList<RadiusAttribute>(this.getAttributes()));
    }
}

