/*
 * SPDX-License-Identifier: Apache-2.0
 * Copyright Red Hat Inc. and Hibernate Authors
 */
package org.hibernate.community.dialect;

import java.util.List;

import org.hibernate.Internal;
import org.hibernate.LockMode;
import org.hibernate.Locking;
import org.hibernate.dialect.DatabaseVersion;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.sql.ast.SQLServerSqlAstTranslator;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.internal.util.collections.Stack;
import org.hibernate.metamodel.mapping.CollectionPart;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.metamodel.mapping.ModelPart;
import org.hibernate.query.IllegalQueryOperationException;
import org.hibernate.query.sqm.tuple.internal.AnonymousTupleTableGroupProducer;
import org.hibernate.query.sqm.ComparisonOperator;
import org.hibernate.query.common.FetchClauseType;
import org.hibernate.sql.ast.Clause;
import org.hibernate.sql.ast.SqlAstJoinType;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.spi.AbstractSqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlSelection;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.Statement;
import org.hibernate.sql.ast.tree.delete.DeleteStatement;
import org.hibernate.sql.ast.tree.expression.BinaryArithmeticExpression;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.expression.Literal;
import org.hibernate.sql.ast.tree.expression.SqlTuple;
import org.hibernate.sql.ast.tree.expression.Summarization;
import org.hibernate.sql.ast.tree.from.DerivedTableReference;
import org.hibernate.sql.ast.tree.from.NamedTableReference;
import org.hibernate.sql.ast.tree.from.TableGroupJoin;
import org.hibernate.sql.ast.tree.from.UnionTableReference;
import org.hibernate.sql.ast.tree.insert.ConflictClause;
import org.hibernate.sql.ast.tree.insert.InsertSelectStatement;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.sql.ast.tree.select.QueryGroup;
import org.hibernate.sql.ast.tree.select.QueryPart;
import org.hibernate.sql.ast.tree.select.QuerySpec;
import org.hibernate.sql.ast.tree.select.SelectClause;
import org.hibernate.sql.ast.tree.select.SortSpecification;
import org.hibernate.sql.ast.tree.update.UpdateStatement;
import org.hibernate.sql.exec.spi.JdbcOperation;
import org.hibernate.type.SqlTypes;


/**
 * A SQL AST translator for SQL Server, used by the community (legacy) SQL Server dialect.
 * <p>
 * Besides SQL Server specific rendering (APPLY for lateral joins, inline table lock hints,
 * {@code top} clauses, casts for XML comparisons), this translator contains branches for old
 * server versions. Before major version 9, query specifications only get a {@code top} clause and
 * a fixed set of lock hints is used. Before major version 11, offset/fetch is emulated with window
 * functions.
 *
 * @author Christian Beikov
 */
public class SQLServerLegacySqlAstTranslator<T extends JdbcOperation> extends AbstractSqlAstTranslator<T> {

	/**
	 * Separator searched for in {@link UnionTableReference} expressions so that a lock hint
	 * can be rendered after every union member.
	 */
	private static final String UNION_ALL = " union all ";

	/**
	 * Join predicate of the lateral join currently being rendered. It is not rendered as an ON
	 * clause of the {@code cross apply}/{@code outer apply}. Instead, the next {@code visitSelectClause}
	 * or {@code visitQueryGroup} call passes it to {@code addAdditionalWherePredicate} and clears it.
	 */
	private Predicate lateralPredicate;

	/**
	 * Creates a translator for the given statement.
	 *
	 * @param sessionFactory the session factory
	 * @param statement the statement to translate
	 */
	public SQLServerLegacySqlAstTranslator(SessionFactoryImplementor sessionFactory, Statement statement) {
		super( sessionFactory, statement );
	}

	/**
	 * Renders an insert statement.
	 * <p>
	 * Without a conflict clause, or with a "do nothing" conflict clause, the plain insert of the
	 * superclass is rendered. Any other conflict clause is emulated with a merge statement, which
	 * is terminated with {@code ;}.
	 */
	@Override
	protected void visitInsertStatementOnly(InsertSelectStatement statement) {
		final ConflictClause conflictClause = statement.getConflictClause();
		if ( conflictClause == null || conflictClause.isDoNothing() ) {
			// Render plain insert statement and possibly run into unique constraint violation
			super.visitInsertStatementOnly( statement );
		}
		else {
			visitInsertStatementEmulateMerge( statement );
			appendSql( ';' );
		}
	}

	/**
	 * Renders {@code delete <alias>}, followed by either {@code from <target table>} (when the
	 * statement has no from-clause roots) or the statement's from clause.
	 */
	@Override
	protected void renderDeleteClause(DeleteStatement statement) {
		appendSql( "delete" );
		final Stack<Clause> clauseStack = getClauseStack();
		try {
			clauseStack.push( Clause.DELETE );
			renderTableReferenceIdentificationVariable( statement.getTargetTable() );
			if ( statement.getFromClause().getRoots().isEmpty() ) {
				appendSql( " from " );
				renderDmlTargetTableExpression( statement.getTargetTable() );
			}
			else {
				visitFromClause( statement.getFromClause() );
			}
		}
		finally {
			clauseStack.pop();
		}
	}

	/**
	 * Renders {@code update} followed only by the target table's identification variable.
	 * The target table expression is not rendered here.
	 */
	@Override
	protected void renderUpdateClause(UpdateStatement updateStatement) {
		appendSql( "update" );
		final Stack<Clause> clauseStack = getClauseStack();
		try {
			clauseStack.push( Clause.UPDATE );
			renderTableReferenceIdentificationVariable( updateStatement.getTargetTable() );
		}
		finally {
			clauseStack.pop();
		}
	}

	/**
	 * Renders the DML target table. Outside of the INSERT clause, it is followed by its
	 * identification variable.
	 */
	@Override
	protected void renderDmlTargetTableExpression(NamedTableReference tableReference) {
		super.renderDmlTargetTableExpression( tableReference );
		if ( getClauseStack().getCurrent() != Clause.INSERT ) {
			renderTableReferenceIdentificationVariable( tableReference );
		}
	}

	/**
	 * Renders the from clause that follows the SET clause of an update.
	 * <p>
	 * When the statement has no from-clause roots, {@code from <target table>} is rendered.
	 * Otherwise the statement's from clause is rendered.
	 */
	@Override
	protected void renderFromClauseAfterUpdateSet(UpdateStatement statement) {
		if ( statement.getFromClause().getRoots().isEmpty() ) {
			appendSql( " from " );
			renderDmlTargetTableExpression( statement.getTargetTable() );
		}
		else {
			visitFromClause( statement.getFromClause() );
		}
	}

	/**
	 * Renders nothing. Conflict handling is done in {@code visitInsertStatementOnly}.
	 *
	 * @throws IllegalQueryOperationException if a "do update" conflict clause names a constraint
	 */
	@Override
	protected void visitConflictClause(ConflictClause conflictClause) {
		if ( conflictClause != null && conflictClause.isDoUpdate() && conflictClause.getConstraintName() != null ) {
			throw new IllegalQueryOperationException( "Insert conflict 'do update' clause with constraint name is not supported" );
		}
	}

	/**
	 * Recursive CTEs are rendered without a {@code recursive} keyword.
	 */
	@Override
	protected boolean needsRecursiveKeywordInWithClause() {
		return false;
	}

	/**
	 * Renders a table group join.
	 * <p>
	 * Lateral joins are rendered as {@code outer apply} (for LEFT joins) or {@code cross apply}.
	 * The join predicate of a lateral join is not rendered as an ON clause. It is stashed in
	 * {@link #lateralPredicate} while the joined group is rendered, so that the sub-query can
	 * pick it up.
	 */
	@Override
	protected void renderTableGroupJoin(TableGroupJoin tableGroupJoin, List<TableGroupJoin> tableGroupJoinCollector) {
		final boolean lateral = tableGroupJoin.getJoinedGroup().isLateral();
		appendSql( WHITESPACE );
		if ( lateral ) {
			appendSql( tableGroupJoin.getJoinType() == SqlAstJoinType.LEFT ? "outer apply " : "cross apply " );
		}
		else {
			appendSql( tableGroupJoin.getJoinType().getText() );
			appendSql( "join " );
		}

		final Predicate predicate = tableGroupJoin.getPredicate();
		if ( predicate == null || predicate.isEmpty() ) {
			renderJoinedTableGroup( tableGroupJoin, null, tableGroupJoinCollector );
		}
		else if ( lateral ) {
			// We have to inject the lateral predicate into the sub-query.
			// Restore the previous value even if rendering fails, so no stale state is left behind.
			final Predicate previousLateralPredicate = this.lateralPredicate;
			this.lateralPredicate = predicate;
			try {
				renderJoinedTableGroup( tableGroupJoin, null, tableGroupJoinCollector );
			}
			finally {
				this.lateralPredicate = previousLateralPredicate;
			}
		}
		else {
			renderJoinedTableGroup( tableGroupJoin, predicate, tableGroupJoinCollector );
		}
	}

	/**
	 * Renders a derived table reference by letting it visit this translator.
	 */
	@Override
	protected void renderDerivedTableReference(DerivedTableReference tableReference) {
		tableReference.accept( this );
	}

	/**
	 * Renders a named set-returning function.
	 * <p>
	 * When the tuple type has an index (ordinality) sub-part, the function call is wrapped in
	 * {@code (select t.*, row_number() over(order by (select 1)) <index column> from <function> t)}.
	 * Otherwise rendering is delegated to the superclass.
	 */
	@Override
	public void renderNamedSetReturningFunction(String functionName, List<? extends SqlAstNode> sqlAstArguments, AnonymousTupleTableGroupProducer tupleType, String tableIdentifierVariable, SqlAstNodeRenderingMode argumentRenderingMode) {
		final ModelPart ordinalitySubPart = tupleType.findSubPart( CollectionPart.Nature.INDEX.getName(), null );
		if ( ordinalitySubPart != null ) {
			appendSql( "(select t.*, row_number() over(order by (select 1)) " );
			appendSql( ordinalitySubPart.asBasicValuedModelPart().getSelectionExpression() );
			appendSql( " from " );
			renderSimpleNamedFunction( functionName, sqlAstArguments, argumentRenderingMode );
			appendSql( " t)" );
		}
		else {
			super.renderNamedSetReturningFunction( functionName, sqlAstArguments, tupleType, tableIdentifierVariable, argumentRenderingMode );
		}
	}

	/**
	 * Renders a named table reference followed by the lock hint for {@code lockMode}.
	 * <p>
	 * For a parenthesized {@link UnionTableReference} with a lock mode other than NONE, the hint
	 * is rendered after every union member instead of after the whole expression.
	 *
	 * @return always {@code true}, because locking is done with inline hints
	 */
	@Override
	protected boolean renderNamedTableReference(NamedTableReference tableReference, LockMode lockMode) {
		final String tableExpression = tableReference.getTableExpression();
		// startsWith (rather than charAt(0)) also copes with an empty table expression
		if ( tableReference instanceof UnionTableReference && lockMode != LockMode.NONE && tableExpression.startsWith( "(" ) ) {
			// SQL Server requires to push down the lock hint to the actual table names
			int searchIndex = 0;
			int unionIndex;
			while ( ( unionIndex = tableExpression.indexOf( UNION_ALL, searchIndex ) ) != -1 ) {
				append( tableExpression, searchIndex, unionIndex );
				renderLockHint( lockMode );
				appendSql( UNION_ALL );
				searchIndex = unionIndex + UNION_ALL.length();
			}
			// Last member: everything up to, but excluding, the closing parenthesis
			append( tableExpression, searchIndex, tableExpression.length() - 1 );
			renderLockHint( lockMode );
			appendSql( " )" );

			registerAffectedTable( tableReference );
			renderTableReferenceIdentificationVariable( tableReference );
		}
		else {
			super.renderNamedTableReference( tableReference, lockMode );
			renderLockHint( lockMode );
		}
		// Just always return true because SQL Server doesn't support the FOR UPDATE clause
		return true;
	}

	/**
	 * Appends the lock hint for {@code lockMode}, using the effective lock timeout for that mode.
	 */
	private void renderLockHint(LockMode lockMode) {
		append( determineLockHint( lockMode, getEffectiveLockTimeout( lockMode ), getDialect() ) );
	}

	/**
	 * Determines the table hint for the given lock mode.
	 * <p>
	 * For dialect versions 9 and later, this delegates to
	 * {@code SQLServerSqlAstTranslator.determineLockHint}. Older versions use a fixed mapping
	 * that ignores the timeout.
	 *
	 * @param lockMode the requested lock mode
	 * @param effectiveLockTimeout the lock timeout (only used for versions 9 and later)
	 * @param dialect the dialect whose version selects the mapping
	 * @return the hint, including its leading space, or an empty string if no hint applies
	 */
	@Internal
	public static String determineLockHint(LockMode lockMode, int effectiveLockTimeout, Dialect dialect) {
		// NOTE: exposed for tests

		if ( dialect.getVersion().isSameOrAfter( 9 ) ) {
			return SQLServerSqlAstTranslator.determineLockHint( lockMode, effectiveLockTimeout );
		}
		else {
			return switch ( lockMode ) {
				case UPGRADE_NOWAIT, PESSIMISTIC_WRITE, WRITE -> " with (updlock,rowlock)";
				case PESSIMISTIC_READ -> " with (holdlock,rowlock)";
				case UPGRADE_SKIPLOCKED -> " with (updlock,rowlock,readpast)";
				default -> "";
			};
		}
	}

	/**
	 * Always returns {@link LockStrategy#CLAUSE}, because lock hints are rendered inline by
	 * {@code renderNamedTableReference}. No follow-on locking is needed.
	 */
	@Override
	protected LockStrategy determineLockingStrategy(
			QuerySpec querySpec,
			Locking.FollowOn followOnLocking) {
		// No need for follow on locking
		return LockStrategy.CLAUSE;
	}

	/**
	 * Decides how the limit/offset of a query part is rendered.
	 * <p>
	 * For a root query part with a {@code Limit}, that limit is used; otherwise the query part's
	 * own fetch/offset expressions are used.
	 *
	 * @return the rendering mode, or {@code null} if there is nothing to render
	 */
	protected OffsetFetchClauseMode getOffsetFetchClauseMode(QueryPart queryPart) {
		final DatabaseVersion version = getDialect().getVersion();
		final boolean hasLimit;
		final boolean hasOffset;
		if ( queryPart.isRoot() && hasLimit() ) {
			hasLimit = getLimit().getMaxRows() != null;
			hasOffset = getLimit().getFirstRow() != null;
		}
		else {
			hasLimit = queryPart.getFetchClauseExpression() != null;
			hasOffset = queryPart.getOffsetClauseExpression() != null;
		}
		if ( queryPart instanceof QueryGroup ) {
			// We can't use TOP for set operations
			if ( hasOffset || hasLimit ) {
				if ( version.isBefore( 11 ) || !isRowsOnlyFetchClauseType( queryPart ) ) {
					return OffsetFetchClauseMode.EMULATED;
				}
				else {
					return OffsetFetchClauseMode.STANDARD;
				}
			}

			return null;
		}
		else {
			if ( version.isBefore( 9 ) || !hasOffset ) {
				return hasLimit ? OffsetFetchClauseMode.TOP_ONLY : null;
			}
			else if ( version.isBefore( 11 ) || !isRowsOnlyFetchClauseType( queryPart ) ) {
				return OffsetFetchClauseMode.EMULATED;
			}
			else if ( !queryPart.hasSortSpecifications() && ((QuerySpec) queryPart).getSelectClause().isDistinct() ) {
				// order by (select 0) workaround for offset / fetch does not work when query is distinct
				return OffsetFetchClauseMode.EMULATED;
			}
			else {
				return OffsetFetchClauseMode.STANDARD;
			}
		}
	}

	/**
	 * Returns whether the fetch/offset of {@code queryPart} must be emulated with window
	 * functions. This is never the case for the query part currently being row-numbered.
	 */
	protected boolean shouldEmulateFetchClause(QueryPart queryPart) {
		// Check if current query part is already row numbering to avoid infinite recursion
		return getQueryPartForRowNumbering() != queryPart && getOffsetFetchClauseMode( queryPart ) == OffsetFetchClauseMode.EMULATED;
	}

	/**
	 * Visits a query group. A pending lateral predicate is moved to
	 * {@code addAdditionalWherePredicate} first. Offset/fetch is emulated with window functions
	 * when required.
	 */
	@Override
	public void visitQueryGroup(QueryGroup queryGroup) {
		final Predicate lateralPredicate = this.lateralPredicate;
		if ( lateralPredicate != null ) {
			this.lateralPredicate = null;
			addAdditionalWherePredicate( lateralPredicate );
		}
		if ( shouldEmulateFetchClause( queryGroup ) ) {
			emulateFetchOffsetWithWindowFunctions( queryGroup, !isRowsOnlyFetchClauseType( queryGroup ) );
		}
		else {
			super.visitQueryGroup( queryGroup );
		}
	}

	/**
	 * Visits a query specification. Offset/fetch is emulated with window functions when required.
	 */
	@Override
	public void visitQuerySpec(QuerySpec querySpec) {
		if ( shouldEmulateFetchClause( querySpec ) ) {
			emulateFetchOffsetWithWindowFunctions( querySpec, !isRowsOnlyFetchClauseType( querySpec ) );
		}
		else {
			super.visitQuerySpec( querySpec );
		}
	}

	/**
	 * Visits a select clause. A pending lateral predicate is first handed to
	 * {@code addAdditionalWherePredicate} and cleared.
	 */
	@Override
	public void visitSelectClause(SelectClause selectClause) {
		if ( lateralPredicate != null ) {
			addAdditionalWherePredicate( lateralPredicate );
			lateralPredicate = null;
		}
		super.visitSelectClause( selectClause );
	}

	/**
	 * Returns {@code true} for dialect versions before 9.
	 */
	@Override
	protected boolean needsRowsToSkip() {
		return getDialect().getVersion().isBefore( 9 );
	}

	/**
	 * Renders the select items, preceded by a {@code top} clause when the offset/fetch mode
	 * requires one. A {@code top 100 percent} is used for ordered query specs nested in a query
	 * group.
	 */
	@Override
	protected void visitSqlSelections(SelectClause selectClause) {
		final QuerySpec querySpec = (QuerySpec) getQueryPartStack().getCurrent();
		final OffsetFetchClauseMode offsetFetchClauseMode = getOffsetFetchClauseMode( querySpec );
		if ( offsetFetchClauseMode == OffsetFetchClauseMode.TOP_ONLY ) {
			renderTopClause( querySpec, true, true );
		}
		else if ( offsetFetchClauseMode == OffsetFetchClauseMode.EMULATED ) {
			renderTopClause( querySpec, isRowsOnlyFetchClauseType( querySpec ), true );
		}
		else if ( getQueryPartStack().depth() > 1 && querySpec.hasSortSpecifications()
				&& getQueryPartStack().peek( 1 ) instanceof QueryGroup ) {
			// If the current query spec has a query group parent, no offset/fetch clause, but an order by clause,
			// then we must render "top 100 percent" as that is needed for the SQL to be valid
			appendSql( "top 100 percent " );
		}
		super.visitSqlSelections( selectClause );
	}

	/**
	 * Renders the order by clause. Inside an OVER clause, an empty sort specification list still
	 * produces {@code order by (select 0)}.
	 */
	@Override
	protected void renderOrderBy(boolean addWhitespace, List<SortSpecification> sortSpecifications) {
		if ( sortSpecifications != null && !sortSpecifications.isEmpty() ) {
			super.renderOrderBy( addWhitespace, sortSpecifications );
		}
		else if ( getClauseStack().getCurrent() == Clause.OVER ) {
			if ( addWhitespace ) {
				appendSql( ' ' );
			}
			renderEmptyOrderBy();
		}
	}

	/**
	 * Renders a no-op {@code order by (select 0)} for places where an order by is mandatory.
	 */
	protected void renderEmptyOrderBy() {
		// Always need an order by clause: https://blog.jooq.org/2014/05/13/sql-server-trick-circumvent-missing-order-by-clause/
		appendSql( "order by (select 0)" );
	}

	/**
	 * Renders the standard {@code offset ... rows fetch ...} clause when the offset/fetch mode is
	 * {@link OffsetFetchClauseMode#STANDARD}. An empty order by and {@code offset 0 rows} are added
	 * when missing. Nothing is rendered for the query part currently being row-numbered.
	 *
	 * @throws IllegalArgumentException for an offset in a subquery on dialect versions before 9
	 */
	@Override
	public void visitOffsetFetchClause(QueryPart queryPart) {
		if ( isRowNumberingCurrentQueryPart() ) {
			return;
		}
		if ( getDialect().getVersion().isBefore( 9 ) && !queryPart.isRoot() && queryPart.getOffsetClauseExpression() != null ) {
			throw new IllegalArgumentException( "Can't emulate offset clause in subquery" );
		}
		if ( getOffsetFetchClauseMode( queryPart ) != OffsetFetchClauseMode.STANDARD ) {
			return;
		}
		if ( !queryPart.hasSortSpecifications() ) {
			// OFFSET/FETCH needs an ORDER BY
			appendSql( ' ' );
			renderEmptyOrderBy();
		}
		final Expression offsetExpression;
		final Expression fetchExpression;
		final FetchClauseType fetchClauseType;
		if ( queryPart.isRoot() && hasLimit() ) {
			prepareLimitOffsetParameters();
			offsetExpression = getOffsetParameter();
			fetchExpression = getLimitParameter();
			fetchClauseType = FetchClauseType.ROWS_ONLY;
		}
		else {
			offsetExpression = queryPart.getOffsetClauseExpression();
			fetchExpression = queryPart.getFetchClauseExpression();
			fetchClauseType = queryPart.getFetchClauseType();
		}
		if ( offsetExpression == null ) {
			appendSql( " offset 0 rows" );
		}
		else {
			renderOffset( offsetExpression, true );
		}

		if ( fetchExpression != null ) {
			renderFetch( fetchExpression, null, fetchClauseType );
		}
	}

	/**
	 * Renders a comparison.
	 * <p>
	 * When the left-hand side is a single SQLXML column, both operands of EQUAL, NOT_EQUAL,
	 * DISTINCT_FROM and NOT_DISTINCT_FROM are cast to {@code nvarchar(max)}. Without native
	 * distinct-from support, the distinct-from operators are emulated with
	 * {@code [not] exists (... intersect ...)}. All other comparisons use the standard rendering
	 * or the intersect emulation, depending on dialect support.
	 */
	@Override
	protected void renderComparison(Expression lhs, ComparisonOperator operator, Expression rhs) {
		final boolean supportsDistinctFrom = getDialect().supportsDistinctFromPredicate();
		final JdbcMappingContainer lhsExpressionType = lhs.getExpressionType();
		if ( lhsExpressionType != null && lhsExpressionType.getJdbcTypeCount() == 1
				&& lhsExpressionType.getSingleJdbcMapping().getJdbcType().getDdlTypeCode() == SqlTypes.SQLXML ) {
			// In SQL Server, XMLTYPE is not "comparable", so we have to cast the two parts to nvarchar(max) for this purpose
			switch ( operator ) {
				case DISTINCT_FROM:
					if ( !supportsDistinctFrom ) {
						// "not exists (... intersect ...)" is the emulation of "is distinct from"
						appendSql( "not " );
					}
					// intentional fall-through
				case NOT_DISTINCT_FROM:
					if ( !supportsDistinctFrom ) {
						renderXmlIntersectComparison( lhs, rhs );
						return;
					}
					// intentional fall-through: the operator is supported natively, only the casts are needed
				case EQUAL:
				case NOT_EQUAL:
					appendSql( "cast(" );
					lhs.accept( this );
					appendSql( " as nvarchar(max))" );
					appendSql( operator.sqlText() );
					appendSql( "cast(" );
					rhs.accept( this );
					appendSql( " as nvarchar(max))" );
					return;
				default:
					// Other operators are rendered by the generic logic below
					break;
			}
		}
		if ( supportsDistinctFrom ) {
			renderComparisonStandard( lhs, operator, rhs );
		}
		else {
			renderComparisonEmulateIntersect( lhs, operator, rhs );
		}
	}

	/**
	 * Renders {@code exists (select cast(lhs as nvarchar(max)) intersect select cast(rhs as nvarchar(max)))}.
	 * The clause stack is always restored, even if rendering an operand fails.
	 */
	private void renderXmlIntersectComparison(Expression lhs, Expression rhs) {
		appendSql( "exists (select cast(" );
		getClauseStack().push( Clause.SELECT );
		try {
			visitSqlSelectExpression( lhs );
			appendSql( " as nvarchar(max))" );
			appendSql( getFromDualForSelectOnly() );
			appendSql( " intersect select cast(" );
			visitSqlSelectExpression( rhs );
			appendSql( " as nvarchar(max))" );
			appendSql( getFromDualForSelectOnly() );
		}
		finally {
			getClauseStack().pop();
		}
		appendSql( CLOSE_PARENTHESIS );
	}

	/**
	 * Renders a tuple comparison against select items by emulating it with scalar comparisons.
	 */
	@Override
	protected void renderSelectTupleComparison(
			List<SqlSelection> lhsExpressions,
			SqlTuple tuple,
			ComparisonOperator operator) {
		emulateSelectTupleComparison( lhsExpressions, tuple.getExpressions(), operator, true );
	}

	/**
	 * Renders a partition/grouping item.
	 * <p>
	 * Literals are rendered as {@code ()}. A {@link Summarization} is rendered as its
	 * comma-separated groupings followed by {@code with} and the summarization kind. Anything
	 * else is rendered as is.
	 */
	@Override
	protected void renderPartitionItem(Expression expression) {
		if ( expression instanceof Literal ) {
			appendSql( "()" );
		}
		else if ( expression instanceof Summarization summarization ) {
			renderCommaSeparated( summarization.getGroupings() );
			appendSql( " with " );
			appendSql( summarization.getKind().sqlText() );
		}
		else {
			expression.accept( this );
		}
	}

	/**
	 * Renders a binary arithmetic expression wrapped in parentheses.
	 */
	@Override
	public void visitBinaryArithmeticExpression(BinaryArithmeticExpression arithmeticExpression) {
		appendSql( OPEN_PARENTHESIS );
		visitArithmeticOperand( arithmeticExpression.getLeftHandOperand() );
		appendSql( arithmeticExpression.getOperator().getOperatorSqlTextString() );
		visitArithmeticOperand( arithmeticExpression.getRightHandOperand() );
		appendSql( CLOSE_PARENTHESIS );
	}

	/**
	 * How the limit/offset of a query part is rendered.
	 */
	enum OffsetFetchClauseMode {
		/** Rendered as a standard {@code offset ... rows fetch ...} clause. */
		STANDARD,
		/** Rendered as a {@code top} clause in the select list. */
		TOP_ONLY,
		/** Emulated with window functions. */
		EMULATED;
	}

	/**
	 * Renders {@code charindex(needle collate Latin1_General_100_BIN2, haystack)>0}.
	 */
	@Override
	protected void renderStringContainsExactlyPredicate(Expression haystack, Expression needle) {
		// SQL Server ignores NUL characters in string on case-insensitive collations, so we force a binary collation.
		// This is needed for the emulation of cycle detection in recursive queries
		appendSql( "charindex(" );
		needle.accept( this );
		appendSql( " collate Latin1_General_100_BIN2," );
		haystack.accept( this );
		appendSql( ")>0" );
	}
}
