/*
 * Copyright 2021 the original author or authors.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * https://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.openrewrite.java.dependencies;

import com.fasterxml.jackson.databind.MappingIterator;
import com.fasterxml.jackson.dataformat.csv.CsvMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import lombok.*;
import lombok.experimental.NonFinal;
import org.jspecify.annotations.Nullable;
import org.openrewrite.*;
import org.openrewrite.gradle.marker.GradleDependencyConfiguration;
import org.openrewrite.gradle.marker.GradleProject;
import org.openrewrite.groovy.GroovyIsoVisitor;
import org.openrewrite.groovy.GroovyVisitor;
import org.openrewrite.groovy.tree.G;
import org.openrewrite.internal.StringUtils;
import org.openrewrite.java.dependencies.internal.StaticVersionComparator;
import org.openrewrite.java.dependencies.internal.Version;
import org.openrewrite.java.dependencies.internal.VersionParser;
import org.openrewrite.java.dependencies.table.VulnerabilityReport;
import org.openrewrite.java.marker.JavaProject;
import org.openrewrite.marker.CommitMessage;
import org.openrewrite.maven.*;
import org.openrewrite.maven.internal.MavenPomDownloader;
import org.openrewrite.maven.table.MavenMetadataFailures;
import org.openrewrite.maven.tree.*;
import org.openrewrite.semver.LatestPatch;
import org.openrewrite.xml.tree.Xml;

import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Reports dependencies matching entries in the bundled advisory database ({@code /advisories-maven.csv}) and
 * upgrades those whose fixed version is reachable with a patch-level version change.
 * <p>
 * During scanning, resolved Maven dependencies and Gradle (Groovy DSL) configurations are matched against the
 * database per project. {@link #generate} writes one {@link VulnerabilityReport} row per match. The editing visitor
 * first tries {@code UpgradeDependencyVersion} for each upgradeable GAV and, when that leaves the tree unchanged,
 * falls back to {@code UpgradeTransitiveDependencyVersion}.
 */
@Value
@EqualsAndHashCode(callSuper = false)
public class DependencyVulnerabilityCheck extends ScanningRecipe<DependencyVulnerabilityCheck.Accumulator> {
    transient MavenMetadataFailures metadataFailures = new MavenMetadataFailures(this);
    transient VersionParser versionParser = new VersionParser();
    transient VulnerabilityReport report = new VulnerabilityReport(this);

    @Option(displayName = "Scope",
            description = "Match dependencies with the specified scope. Default is `compile`.",
            valid = {"compile", "test", "runtime", "provided"},
            example = "compile",
            required = false)
    @Nullable
    String scope;

    @Option(displayName = "Override transitives",
            description = "When enabled transitive dependencies with vulnerabilities will have their versions overridden. " +
                          "By default only direct dependencies have their version numbers upgraded.",
            example = "false",
            required = false)
    @Nullable
    Boolean overrideTransitive;

    @Override
    public String getDisplayName() {
        return "Find and fix vulnerable dependencies";
    }

    @Override
    public String getDescription() {
        //language=markdown
        return "This software composition analysis (SCA) tool detects and upgrades dependencies with publicly disclosed vulnerabilities. " +
               "This recipe both generates a report of vulnerable dependencies and upgrades to newer versions with fixes. " +
               "This recipe **only** upgrades to the latest **patch** version.  If a minor or major upgrade is required to reach the fixed version, this recipe will not make any changes. " +
               "Vulnerability information comes from the [GitHub Security Advisory Database](https://docs.github.com/en/code-security/security-advisories/global-security-advisories/about-the-github-advisory-database), " +
               "which aggregates vulnerability data from several public databases, including the [National Vulnerability Database](https://nvd.nist.gov/) maintained by the United States government. " +
               "Dependencies following [Semantic Versioning](https://semver.org/) will see their _patch_ version updated where applicable.";
    }

    /**
     * Validates the recipe options. In addition to the inherited checks, the {@code scope} option must name a
     * recognized Maven scope.
     */
    @Override
    public Validated<Object> validate() {
        return super.validate().and(Validated.test("scope", "scope is a valid Maven scope", scope, s -> {
            try {
                // Scope.fromName is expected to return Scope.Invalid for unrecognized names rather than throwing,
                // so the parsed value has to be checked explicitly.
                return Scope.fromName(s) != Scope.Invalid;
            } catch (RuntimeException e) {
                return false;
            }
        }));
    }

    /**
     * State shared between the scanning, generating and editing phases of the recipe.
     */
    @Getter
    @RequiredArgsConstructor
    public static class Accumulator {
        /** Advisories loaded from the bundled database, keyed by the affected group and artifact. */
        final Map<GroupArtifact, List<Vulnerability>> db;
        /** The scope parsed from the recipe's {@code scope} option. */
        final Scope scope;
        /** Accumulator handed to the delegated {@code UpgradeDependencyVersion} scanner and visitor. */
        final org.openrewrite.java.dependencies.UpgradeDependencyVersion.Accumulator dependencyAcc;
        /** Accumulator handed to the delegated {@code UpgradeTransitiveDependencyVersion} scanner and visitor. */
        final AddManagedDependency.Scanned transitiveAcc;


        /** Matched vulnerabilities for each scanned project, keyed by project name ("" when unknown). */
        Map<String, Vulnerabilities> projectToVulnerabilities = new LinkedHashMap<>();

        /**
         * Collects the Maven repositories declared by a Maven POM or Gradle project marker on {@code s}, so that
         * version metadata can later be looked up in the same repositories.
         */
        public void repositoriesFrom(SourceFile s) {
            s.getMarkers().findFirst(MavenResolutionResult.class)
                    .ifPresent(mrr -> repositories.addAll(mrr.getPom().getRepositories()));
            s.getMarkers().findFirst(GradleProject.class)
                    .ifPresent(gradleProject -> repositories.addAll(gradleProject.getMavenRepositories()));
        }

        private Set<MavenRepository> repositories = new LinkedHashSet<>();

        @Nullable
        private List<MavenRepository> allRepositories = null;

        /**
         * Returns the repositories collected by {@link #repositoriesFrom}, in discovery order. The list is
         * snapshotted on the first call, so repositories added afterwards are not reflected.
         */
        public List<MavenRepository> getRepositories() {
            if (allRepositories == null) {
                allRepositories = new ArrayList<>(repositories);
            }
            return allRepositories;
        }

        @Nullable
        private Map<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> upgradeableVulnerabilities = null;

        /**
         * Returns, for every vulnerable GAV across all projects, the vulnerabilities that have a non-empty fixed
         * version which {@link LatestPatch} accepts as a patch upgrade from the resolved version. GAVs with no such
         * vulnerability are omitted.
         * <p>
         * The result is computed on the first call and cached, so call this only after scanning is complete.
         */
        public Map<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> upgradeableVulnerabilities() {
            if (upgradeableVulnerabilities == null) {
                LatestPatch latestPatch = new LatestPatch(null);
                Map<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> result = new LinkedHashMap<>();
                for (Vulnerabilities vuln : projectToVulnerabilities.values()) {
                    for (Map.Entry<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> entry : vuln.getGavToVulnerabilities().entrySet()) {
                        ResolvedGroupArtifactVersion gav = entry.getKey();
                        Set<MinimumDepthVulnerability> upgradeable = entry.getValue().stream()
                                .filter(it -> StringUtils.isNotEmpty(it.getVulnerability().getFixedVersion()))
                                .filter(it -> latestPatch.isValid(gav.getVersion(), it.getVulnerability().getFixedVersion()))
                                .collect(Collectors.toCollection(LinkedHashSet::new));
                        // The same GAV may appear in several projects; merge their upgradeable vulnerabilities.
                        if (!upgradeable.isEmpty()) {
                            result.computeIfAbsent(gav, k -> new LinkedHashSet<>()).addAll(upgradeable);
                        }
                    }
                }
                upgradeableVulnerabilities = result;
            }
            return upgradeableVulnerabilities;
        }
    }

    /**
     * Vulnerabilities matched within one project, keyed by the resolved GAV they affect.
     */
    @Value
    public static class Vulnerabilities {
        Map<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> gavToVulnerabilities;

        /** Delegates to {@link Map#computeIfAbsent} on {@link #gavToVulnerabilities}. */
        public @Nullable Set<MinimumDepthVulnerability> computeIfAbsent(ResolvedGroupArtifactVersion gav, Function<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> mappingFunction) {
            return gavToVulnerabilities.computeIfAbsent(gav, mappingFunction);
        }
    }

    /**
     * Loads the bundled advisory database from the classpath and creates the accumulators for the delegated
     * upgrade recipes.
     *
     * @throws IllegalStateException if {@code /advisories-maven.csv} is not on the classpath
     * @throws RuntimeException      wrapping any {@link IOException} raised while reading the database
     */
    @Override
    public Accumulator getInitialValue(ExecutionContext ctx) {
        Scope parsedScope = Scope.fromName(scope);
        CsvMapper csvMapper = new CsvMapper();
        csvMapper.registerModule(new JavaTimeModule());
        Map<GroupArtifact, List<Vulnerability>> db = new HashMap<>();

        try (InputStream resourceAsStream = DependencyVulnerabilityCheck.class.getResourceAsStream("/advisories-maven.csv")) {
            // getResourceAsStream reports a missing resource as null; fail with a clear message instead of
            // handing a null stream to the CSV reader.
            if (resourceAsStream == null) {
                throw new IllegalStateException("Vulnerability database /advisories-maven.csv not found on the classpath");
            }
            try (MappingIterator<Vulnerability> vs = csvMapper.readerWithSchemaFor(Vulnerability.class).readValues(resourceAsStream)) {
                vs.forEachRemaining(v -> {
                    // Each record identifies its artifact as "groupId:artifactId".
                    String[] ga = v.getGroupArtifact().split(":");
                    db.computeIfAbsent(new GroupArtifact(ga[0], ga[1]), g -> new ArrayList<>()).add(v);
                });
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to read vulnerability database /advisories-maven.csv", e);
        }

        return new Accumulator(db, parsedScope,
                new UpgradeDependencyVersion("", "", "", null, null, null)
                        .getInitialValue(ctx),
                new UpgradeTransitiveDependencyVersion("", "", "", null, null, null, null, null, null, null, true)
                        .getInitialValue(ctx));
    }

    /**
     * For every source file: records its repositories, matches its resolved Maven and Gradle (Groovy)
     * dependencies against the database, and runs the scanners of the delegated upgrade recipes so their
     * accumulators are populated for the editing phase.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getScanner(Accumulator acc) {
        return new TreeVisitor<Tree, ExecutionContext>() {
            @Override
            public @Nullable Tree visit(@Nullable Tree tree, ExecutionContext ctx) {
                if (!(tree instanceof SourceFile)) {
                    return tree;
                }
                acc.repositoriesFrom((SourceFile) tree);
                scanMaven(acc.getDb(), acc.getProjectToVulnerabilities(), acc.getScope()).visitNonNull(tree, ctx);
                scanGradleGroovy(acc.getDb(), acc.getProjectToVulnerabilities(), acc.getScope()).visitNonNull(tree, ctx);
                new org.openrewrite.java.dependencies.UpgradeDependencyVersion("", "", "", null, null, null)
                        .getScanner(acc.getDependencyAcc())
                        .visit(tree, ctx);
                new org.openrewrite.java.dependencies.UpgradeTransitiveDependencyVersion("", "", "", null, null, null, null, null, null, null, true)
                        .getScanner(acc.getTransitiveAcc())
                        .visit(tree, ctx);
                return tree;
            }
        };
    }

    /**
     * Writes one {@link VulnerabilityReport} row per matched vulnerability per project. No source files are
     * generated.
     */
    @Override
    public Collection<SourceFile> generate(Accumulator acc, ExecutionContext ctx) {
        LatestPatch latestPatch = new LatestPatch(null);
        for (Map.Entry<String, Vulnerabilities> projectToVulnerabilities : acc.getProjectToVulnerabilities().entrySet()) {
            String projectName = projectToVulnerabilities.getKey();
            for (Map.Entry<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> vulnerabilitiesByGav : projectToVulnerabilities.getValue().getGavToVulnerabilities().entrySet()) {
                ResolvedGroupArtifactVersion gav = vulnerabilitiesByGav.getKey();
                for (MinimumDepthVulnerability vDepth : vulnerabilitiesByGav.getValue()) {
                    Vulnerability v = vDepth.getVulnerability();
                    // Advisories without a fixed version are still reported, but cannot be fixed by a version
                    // bump; guard before asking LatestPatch (same filter as Accumulator#upgradeableVulnerabilities).
                    boolean fixWithVersionUpdateOnly = StringUtils.isNotEmpty(v.getFixedVersion()) &&
                                                       latestPatch.isValid(gav.getVersion(), v.getFixedVersion());
                    report.insertRow(ctx, new VulnerabilityReport.Row(
                            projectName,
                            v.getCve(),
                            gav.getGroupId(),
                            gav.getArtifactId(),
                            gav.getVersion(),
                            v.getFixedVersion(),
                            fixWithVersionUpdateOnly,
                            v.getSummary(),
                            v.getSeverity().toString(),
                            vDepth.getMinDepth(),
                            v.getCwes()
                    ));
                }
            }
        }
        return Collections.emptyList();
    }

    /**
     * For each upgradeable GAV, tries a direct dependency upgrade and, if that leaves the tree unchanged, a
     * transitive dependency upgrade. Whenever an upgrade changes the tree, the CVEs of that GAV are passed to
     * {@link CommitMessage#message}.
     */
    @Override
    public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
        return new TreeVisitor<Tree, ExecutionContext>() {
            // The version to request depends only on the GAV (its vulnerabilities and the scanned repositories are
            // fixed by now), so resolve it once per GAV instead of once per GAV per source file. This avoids
            // repeated metadata downloads and duplicate metadata-failure rows.
            private final Map<ResolvedGroupArtifactVersion, String> requestedVersions = new HashMap<>();

            @Override
            public @Nullable Tree visit(@Nullable Tree tree, ExecutionContext ctx) {
                if (tree == null) {
                    return null;
                }
                Tree t = tree;
                Map<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> upgradeableVulnerabilities =
                        acc.upgradeableVulnerabilities();
                for (Map.Entry<ResolvedGroupArtifactVersion, Set<MinimumDepthVulnerability>> gavToUpgradeableVulnerabilities : upgradeableVulnerabilities.entrySet()) {
                    ResolvedGroupArtifactVersion gav = gavToUpgradeableVulnerabilities.getKey();
                    Set<MinimumDepthVulnerability> vulnerabilities = gavToUpgradeableVulnerabilities.getValue();
                    String requestedVersion = requestedVersions.computeIfAbsent(gav,
                            k -> versionToRequest(vulnerabilities, acc.getRepositories(), ctx));

                    // Compare against the tree as it was before THIS GAV's upgrade, so that a commit message is
                    // only attributed to vulnerabilities whose upgrade actually changed the file.
                    Tree before = t;
                    Tree after = new UpgradeDependencyVersion(gav.getGroupId(), gav.getArtifactId(), requestedVersion, null, overrideTransitive, null)
                            .getVisitor(acc.getDependencyAcc())
                            .visitNonNull(before, ctx);
                    String because = null;
                    if (after == before) {
                        // Not changed as a direct dependency; try overriding it as a transitive one.
                        because = because(vulnerabilities);
                        after = new UpgradeTransitiveDependencyVersion(gav.getGroupId(), gav.getArtifactId(), requestedVersion, scope, null, null, null, because, null, null, true)
                                .getVisitor(acc.getTransitiveAcc())
                                .visitNonNull(after, ctx);
                    }
                    t = after;

                    if (after != before) {
                        if (because == null) {
                            because = because(vulnerabilities);
                        }
                        CommitMessage.message(after, DependencyVulnerabilityCheck.this, because);
                    }
                }
                return t;
            }
        };
    }

    /**
     * Chooses the version to request when upgrading a vulnerable dependency.
     * <p>
     * Of the vulnerabilities with valid upgrade paths, take the highest fixed version and check whether it can be
     * found in the metadata of the available repositories. A fixed version in the database may be slightly
     * inaccurate, for example missing a suffix (milestone, timestamp, etc.) or disagreeing about ".RELEASE", so the
     * spelling with the ".RELEASE" suffix toggled is checked as well. If no spelling can be confirmed to exist, leave
     * discovery to the upgrade dependency recipes by falling back to "latest.patch".
     *
     * @param vulnerabilities upgradeable vulnerabilities of a single GAV; their fixed versions are expected to be non-empty
     * @param repositories    repositories to download version metadata from
     * @param ctx             execution context used for downloading and for recording metadata failures
     * @return a concrete fixed version that exists in the repository metadata, or {@code "latest.patch"}
     */
    private String versionToRequest(Set<MinimumDepthVulnerability> vulnerabilities, List<MavenRepository> repositories, ExecutionContext ctx) {
        Comparator<Version> vc = new StaticVersionComparator();
        Vulnerability highestFix = vulnerabilities.stream()
                .map(MinimumDepthVulnerability::getVulnerability)
                .max(Comparator.comparing(
                        it -> versionParser.transform(stripExtraneousVersionSuffix(it.getFixedVersion())),
                        vc))
                .orElse(null);
        if (highestFix == null) {
            return "latest.patch";
        }

        String[] groupArtifact = highestFix.getGroupArtifact().split(":");
        GroupArtifact ga = new GroupArtifact(groupArtifact[0], groupArtifact[1]);
        try {
            MavenMetadata metadata = metadataFailures.insertRows(ctx, () -> new MavenPomDownloader(ctx).downloadMetadata(
                    ga, null, repositories));
            List<String> versions = metadata.getVersioning().getVersions();
            for (String candidate : candidateVersions(highestFix.getFixedVersion())) {
                if (versions.contains(candidate)) {
                    return candidate;
                }
            }
        } catch (MavenDownloadingException ignored) {
            // Metadata unavailable; the download failure is recorded by metadataFailures. Fall back below.
        }
        return "latest.patch";
    }

    /**
     * Returns the spellings of {@code fixedVersion} to look for in repository metadata, most preferred first: the
     * version exactly as recorded in the database, then the same version with the ".RELEASE" suffix toggled,
     * because the database is inconsistent about whether that suffix is included.
     */
    private static List<String> candidateVersions(String fixedVersion) {
        String withoutSuffix = stripExtraneousVersionSuffix(fixedVersion);
        String alternate = withoutSuffix.equals(fixedVersion) ? fixedVersion + ".RELEASE" : withoutSuffix;
        return Arrays.asList(fixedVersion, alternate);
    }

    /**
     * Joins the distinct, non-empty CVE identifiers of {@code reasons} with ", ".
     *
     * @return the joined identifiers, or {@code null} when there are none
     */
    private static @Nullable String because(Collection<MinimumDepthVulnerability> reasons) {
        String because = reasons.stream()
                .map(MinimumDepthVulnerability::getVulnerability)
                .map(Vulnerability::getCve)
                .filter(StringUtils::isNotEmpty)
                .distinct()
                .collect(Collectors.joining(", "));
        return StringUtils.isBlank(because) ? null : because;
    }

    /**
     * Creates a visitor that matches the resolved dependencies of a Maven document in scope {@code aScope} against
     * {@code db}, recording matches under the document's project name in {@code projectToVulnerabilities}.
     */
    private MavenVisitor<ExecutionContext> scanMaven(
            Map<GroupArtifact, List<Vulnerability>> db,
            Map<String, Vulnerabilities> projectToVulnerabilities,
            Scope aScope) {
        return new MavenIsoVisitor<ExecutionContext>() {
            @Override
            public Xml.Document visitDocument(Xml.Document document, ExecutionContext ctx) {
                List<ResolvedDependency> scopeDependencies = getResolutionResult().getDependencies().get(aScope);
                if (scopeDependencies != null) {
                    String projectName = projectName(document);
                    for (ResolvedDependency resolvedDependency : scopeDependencies) {
                        analyzeDependency(db,
                                projectToVulnerabilities.computeIfAbsent(projectName, p -> new Vulnerabilities(new LinkedHashMap<>())),
                                resolvedDependency);
                    }
                }
                return document;
            }
        };
    }

    /** Returns the project name from the tree's {@link JavaProject} marker, or "" when there is none. */
    private static String projectName(Tree t) {
        return t.getMarkers().findFirst(JavaProject.class)
                .map(JavaProject::getProjectName)
                .orElse("");
    }

    /**
     * Decides whether a Gradle configuration should be skipped for the requested scope, based on substrings of
     * the configuration name: test scope keeps only names containing "test", compile and runtime keep only names
     * not containing "test", and provided keeps names containing "provided" or "compileOnly". Other scopes keep
     * every configuration.
     *
     * @return {@code true} if the configuration's dependencies should be ignored
     */
    private static boolean scopeExcludesConfiguration(GradleDependencyConfiguration configuration, Scope scope) {
        switch (scope) {
            case Test:
                return !configuration.getName().contains("test");
            case Compile:
            case Runtime:
                return configuration.getName().contains("test");
            case Provided:
                return !configuration.getName().contains("provided") && !configuration.getName().contains("compileOnly");
            default:
                return false;
        }
    }

    /**
     * Creates a visitor that matches the resolved dependencies of the Gradle project attached to a Groovy
     * compilation unit against {@code db}. It skips configurations excluded by
     * {@link #scopeExcludesConfiguration} and dependencies with a blank version.
     */
    private GroovyVisitor<ExecutionContext> scanGradleGroovy(
            Map<GroupArtifact, List<Vulnerability>> db,
            Map<String, Vulnerabilities> projectToVulnerabilities,
            Scope aScope) {
        return new GroovyIsoVisitor<ExecutionContext>() {
            @Override
            public G.CompilationUnit visitCompilationUnit(G.CompilationUnit cu, ExecutionContext ctx) {
                cu.getMarkers().findFirst(GradleProject.class).ifPresent(gradleProject -> {
                    String projectName = projectName(cu);
                    for (GradleDependencyConfiguration configuration : gradleProject.getConfigurations()) {
                        if (scopeExcludesConfiguration(configuration, aScope)) {
                            continue;
                        }
                        for (ResolvedDependency resolvedDependency : configuration.getResolved()) {
                            if (!StringUtils.isBlank(resolvedDependency.getVersion())) {
                                analyzeDependency(db,
                                        projectToVulnerabilities.computeIfAbsent(projectName, p -> new Vulnerabilities(new LinkedHashMap<>())),
                                        resolvedDependency);
                            }
                        }
                    }
                });
                return cu;
            }
        };
    }

    /**
     * Records every advisory in {@code db} that affects the resolved version of {@code resolvedDependency}.
     * <p>
     * A version is affected when it is at or above the introduced version (or the introduced version is blank)
     * and strictly below the fixed version (or the fixed version is blank). Matches are stored in
     * {@code vulnerabilities}, sorted by descending severity and then by CVE. A vulnerability that was already
     * recorded only has its minimum dependency depth lowered.
     */
    private void analyzeDependency(
            Map<GroupArtifact, List<Vulnerability>> db,
            Vulnerabilities vulnerabilities,
            ResolvedDependency resolvedDependency) {
        List<Vulnerability> vs = db.get(new GroupArtifact(resolvedDependency.getGroupId(), resolvedDependency.getArtifactId()));
        if (vs == null) {
            return;
        }
        Comparator<Version> vc = new StaticVersionComparator();
        // Some dependencies have a ".RELEASE" suffix.
        // For example spring-security-core had a .RELEASE suffix for versions >=2.0.5 and <5.4.0. No suffixes since then
        // The vulnerability database is inconsistent about whether the ".RELEASE" is included in the fixed version
        // This inconsistency complicates comparisons because "5.3.0" != "5.3.0.RELEASE"
        // This inconsistency complicates dependency upgrade since we don't know which version number format to request
        // Therefore ignore the suffix during comparison; versionToRequest tries both spellings when upgrading
        // The edge case of ".RELEASE" being introduced into a version scheme between patch versions is possible but hopefully rare
        Version resolvedVersion = versionParser.transform(stripExtraneousVersionSuffix(resolvedDependency.getVersion()));
        Set<MinimumDepthVulnerability> gavVs = null;

        nextVulnerability:
        for (Vulnerability v : vs) {
            // A blank fixed version means no fix is known, so every version from the introduced one onwards is affected.
            boolean isLessThanFixed = StringUtils.isBlank(v.getFixedVersion()) ||
                                      vc.compare(versionParser.transform(stripExtraneousVersionSuffix(v.getFixedVersion())), resolvedVersion) > 0;
            if (!isLessThanFixed) {
                continue;
            }
            // A blank introduced version is treated as "affected since the first release".
            boolean isAtLeastIntroduced = StringUtils.isBlank(v.getIntroducedVersion()) ||
                                          vc.compare(versionParser.transform(stripExtraneousVersionSuffix(v.getIntroducedVersion())), resolvedVersion) <= 0;
            if (!isAtLeastIntroduced) {
                continue;
            }

            if (gavVs == null) {
                gavVs = vulnerabilities.computeIfAbsent(resolvedDependency.getGav(), ga -> new TreeSet<>(
                        Comparator.comparing((MinimumDepthVulnerability vDep) -> vDep.getVulnerability().getSeverity()).reversed()
                                .thenComparing((MinimumDepthVulnerability vDep) -> vDep.getVulnerability().getCve())));
            }

            // Already recorded through another path (e.g. another configuration): keep the shallowest depth.
            for (MinimumDepthVulnerability vDep : gavVs) {
                if (vDep.getVulnerability().equals(v)) {
                    vDep.minDepth = Math.min(vDep.minDepth, resolvedDependency.getDepth());
                    continue nextVulnerability;
                }
            }

            gavVs.add(new MinimumDepthVulnerability(resolvedDependency.getDepth(), v));
        }
    }


    /**
     * A vulnerability together with the smallest dependency depth at which it was found.
     */
    @Value
    public static class MinimumDepthVulnerability {
        // Mutable so analyzeDependency can lower it when the same vulnerability is seen at a shallower depth.
        @NonFinal
        int minDepth;

        Vulnerability vulnerability;
    }

    /** Removes a trailing ".RELEASE" from {@code version}, if present; otherwise returns it unchanged. */
    private static String stripExtraneousVersionSuffix(String version) {
        if (version.endsWith(".RELEASE")) {
            return version.substring(0, version.length() - ".RELEASE".length());
        }
        return version;
    }
}
