Skip to content

Commit

Permalink
Translate non-aggregate string.Join
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Aug 26, 2022
1 parent 1c71c96 commit 5283633
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslating
ExpressionType.Modulo
};

private static readonly MethodInfo StringJoinMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(string[]) })!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -152,6 +155,97 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
return base.VisitUnary(unaryExpression);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
var translation = base.VisitMethodCall(methodCallExpression);

if (translation != QueryCompilationContext.NotTranslatedExpression)
{
return translation;
}

if (methodCallExpression.Method == StringJoinMethodInfo)
{
if (methodCallExpression.Arguments[1] is not NewArrayExpression newArrayExpression)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var sqlArguments = new SqlExpression[newArrayExpression.Expressions.Count + 1];

if (TranslationFailed(methodCallExpression.Arguments[0], Visit(methodCallExpression.Arguments[0]), out var sqlDelimiter))
{
return QueryCompilationContext.NotTranslatedExpression;
}

sqlArguments[0] = sqlDelimiter!;

var isUnicode = sqlDelimiter!.TypeMapping?.IsUnicode == true;

for (var i = 0; i < newArrayExpression.Expressions.Count; i++)
{
var argument = newArrayExpression.Expressions[i];
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return QueryCompilationContext.NotTranslatedExpression;
}

// CONCAT_WS returns a type with a length that varies based on actual inputs (i.e. the sum of all argument lengths, plus
// the length needed for the delimiters). We don't know about parameter values here, so we always return max.
// We do vary return varchar(max) or nvarchar(max) based on whether we saw any nvarchar mapping.
if (sqlArgument!.TypeMapping?.IsUnicode == true)
{
isUnicode = true;
}

// CONCAT_WS filters out nulls, but string.Join treats them as empty strings; coalesce unless we know we have a non-nullable
// argument.
sqlArguments[i + 1] = sqlArgument switch
{
ColumnExpression { IsNullable: false } => sqlArgument,
SqlConstantExpression constantExpression => constantExpression.Value is null
? new SqlConstantExpression(Expression.Constant(string.Empty, typeof(string)), null)
: constantExpression,
_ => Dependencies.SqlExpressionFactory.Coalesce(
sqlArgument,
Dependencies.SqlExpressionFactory.Constant(string.Empty, typeof(string)))
};
}

// CONCAT_WS never returns null; a null delimiter is interpreted as an empty string, and null arguments are skipped
// (but we coalesce them above in any case).
return Dependencies.SqlExpressionFactory.Function(
"CONCAT_WS",
sqlArguments,
nullable: false,
argumentsPropagateNullability: new bool[sqlArguments.Length],
methodCallExpression.Method.ReturnType,
Dependencies.TypeMappingSource.FindMapping(isUnicode ? "nvarchar(max)" : "varchar(max)"));
}

return QueryCompilationContext.NotTranslatedExpression;
}

private static string? GetProviderType(SqlExpression expression)
=> expression.TypeMapping?.StoreType;

[DebuggerStepThrough]
private static bool TranslationFailed(Expression? original, Expression? translation, out SqlExpression? castTranslation)
{
if (original != null
&& translation is not SqlExpression)
{
castTranslation = null;
return true;
}

castTranslation = translation as SqlExpression;
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,18 @@ public virtual Task String_Join_over_nullable_column(bool async)
a.Regions.Split("|").OrderBy(id => id).ToArray());
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Join_non_aggregate(bool async)
{
var foo = "foo";

return AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => string.Join("|", c.CompanyName, foo, "bar") == "Around the Horn|foo|bar"),
entryCount: 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Concat(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,19 @@ public override async Task String_Join_with_ordering(bool async)
GROUP BY [c].[City]");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Join_non_aggregate(bool async)
{
await base.String_Join_non_aggregate(async);

AssertSql(
@"@__foo_0='foo' (Size = 4000)
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE CONCAT_WS(N'|', COALESCE([c].[CompanyName], N''), COALESCE(@__foo_0, N''), N'bar') = N'Around the Horn|foo|bar'");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Concat(bool async)
{
Expand Down

0 comments on commit 5283633

Please sign in to comment.