Skip to content

Commit

Permalink
Merge pull request #1752 from riganti/lambda-infer-any-delegate
Browse files Browse the repository at this point in the history
lambda inferrer: support custom delegates
  • Loading branch information
tomasherceg committed Feb 8, 2024
2 parents 33a7fec + dce2157 commit b4aab99
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 43 deletions.
Expand Up @@ -457,6 +457,8 @@ protected override Expression VisitLambda(LambdaBindingParserNode node)
for (var paramIndex = 0; paramIndex < typeInferenceData.Parameters!.Length; paramIndex++)
{
var currentParamType = typeInferenceData.Parameters[paramIndex];
if (currentParamType.ContainsGenericParameters)
throw new BindingCompilationException($"Internal bug: lambda parameter still contains generic arguments: parameters[{paramIndex}] = {currentParamType.ToCode()}", node);
node.ParameterExpressions[paramIndex].SetResolvedType(currentParamType);
}
}
Expand Down Expand Up @@ -506,26 +508,90 @@ protected override Expression VisitLambdaParameter(LambdaParameterBindingParserN

private Expression CreateLambdaExpression(Expression body, ParameterExpression[] parameters, Type? delegateType)
{
if (delegateType != null && delegateType.Namespace == "System")
if (delegateType is null || delegateType == typeof(object) || delegateType == typeof(Delegate))
// Assume delegate is a System.Func<...>
return Expression.Lambda(body, parameters);

if (!delegateType.IsDelegate(out var invokeMethod))
throw new DotvvmCompilationException($"Cannot create lambda function, type '{delegateType.ToCode()}' is not a delegate type.");

if (invokeMethod.ReturnType == typeof(void))
{
// We must validate that lambda body contains a valid statement
if ((body.NodeType != ExpressionType.Default) && (body.NodeType != ExpressionType.Block) && (body.NodeType != ExpressionType.Call) && (body.NodeType != ExpressionType.Assign))
throw new DotvvmCompilationException($"Only method invocations and assignments can be used as statements.");

// Make sure the result type will be void by adding an empty expression
body = Expression.Block(body, Expression.Empty());
}

// convert body result to the delegate return type
if (invokeMethod.ReturnType.ContainsGenericParameters)
{
if (delegateType.Name == "Action" || delegateType.Name == $"Action`{parameters.Length}")
if (invokeMethod.ReturnType.IsGenericType)
{
// We must validate that lambda body contains a valid statement
if ((body.NodeType != ExpressionType.Default) && (body.NodeType != ExpressionType.Block) && (body.NodeType != ExpressionType.Call) && (body.NodeType != ExpressionType.Assign))
throw new DotvvmCompilationException($"Only method invocations and assignments can be used as statements.");
// no fancy implicit conversions are supported, only inheritance
if (!ReflectionUtils.IsAssignableToGenericType(body.Type, invokeMethod.ReturnType.GetGenericTypeDefinition(), out var bodyReturnType))
{
throw new DotvvmCompilationException($"Cannot convert lambda function body of type '{body.Type.ToCode()}' to the delegate return type '{invokeMethod.ReturnType.ToCode()}'.");
}
else
{
body = Expression.Convert(body, bodyReturnType);
}
}
else
{
// fine, we will unify it in the next step

// Make sure the result type will be void by adding an empty expression
return Expression.Lambda(Expression.Block(body, Expression.Empty()), parameters);
// Some complex conversions like Tuple<T, List<object>> -> Tuple<T, IEnumerable<T2>>
// will fail, but we don't have to support everything
}
else if (delegateType.Name == "Predicate`1")
}
else
{
body = TypeConversion.EnsureImplicitConversion(body, invokeMethod.ReturnType);
}

if (delegateType.ContainsGenericParameters)
{
var delegateTypeDef = delegateType.GetGenericTypeDefinition();
// The delegate is either purely generic (Func<T, T>) or only some of the generic arguments are known (Func<T, bool>)
// initialize generic args with the already known types
var genericArgs =
delegateTypeDef.GetGenericArguments().Zip(
delegateType.GetGenericArguments(),
(param, argument) => new KeyValuePair<Type, Type>(param, argument)
)
.Where(p => p.Value != p.Key)
.ToDictionary(p => p.Key, p => p.Value);

var delegateParameters = invokeMethod.GetParameters();
for (int i = 0; i < parameters.Length; i++)
{
if (!ReflectionUtils.TryUnifyGenericTypes(delegateParameters[i].ParameterType, parameters[i].Type, genericArgs))
{
throw new DotvvmCompilationException($"Could not match lambda function parameter '{parameters[i].Type.ToCode()} {parameters[i].Name}' to delegate parameter '{delegateParameters[i].ParameterType.ToCode()} {delegateParameters[i].Name}'.");
}
}
if (!ReflectionUtils.TryUnifyGenericTypes(invokeMethod.ReturnType, body.Type, genericArgs))
{
var type = delegateType.GetGenericTypeDefinition().MakeGenericType(parameters.Single().Type);
return Expression.Lambda(type, body, parameters);
throw new DotvvmCompilationException($"Could not match lambda function return type '{body.Type.ToCode()}' to delegate return type '{invokeMethod.ReturnType.ToCode()}'.");
}
ReflectionUtils.ExpandUnifiedTypes(genericArgs);

if (!delegateTypeDef.GetGenericArguments().All(a => genericArgs.TryGetValue(a, out var v) && !v.ContainsGenericParameters))
{
var missingGenericArgs = delegateTypeDef.GetGenericArguments().Where(genericArg => !genericArgs.ContainsKey(genericArg) || genericArgs[genericArg].ContainsGenericParameters);
throw new DotvvmCompilationException($"Could not infer all generic arguments ({string.Join(", ", missingGenericArgs)}) of delegate type '{delegateType.ToCode()}' from lambda expression '({string.Join(", ", parameters.Select(p => $"{p.Type.ToCode()} {p.Name}"))}) => ...'.");
}

delegateType = delegateTypeDef.MakeGenericType(
delegateTypeDef.GetGenericArguments().Select(genericParam => genericArgs[genericParam]).ToArray()
);
}

// Assume delegate is a System.Func<...>
return Expression.Lambda(body, parameters);
return Expression.Lambda(delegateType, body, parameters);
}

protected override Expression VisitBlock(BlockBindingParserNode node)
Expand Down
Expand Up @@ -9,15 +9,15 @@ internal class InfererContext
{
public MethodGroupExpression? Target { get; set; }
public Expression[] Arguments { get; set; }
public Dictionary<string, Type> Generics { get; set; }
public Dictionary<Type, Type> Generics { get; set; }
public int CurrentArgumentIndex { get; set; }
public bool IsExtensionCall { get; set; }

public InfererContext(MethodGroupExpression? target, int argsCount)
{
this.Target = target;
this.Arguments = new Expression[argsCount];
this.Generics = new Dictionary<string, Type>();
this.Generics = new();
}
}
}
Expand Up @@ -94,39 +94,32 @@ private bool TryMatchDelegate(InfererContext? context, int argsCount, Type deleg
if (delegateParameters.Length != argsCount)
return false;

var generics = (context != null) ? context.Generics : new Dictionary<string, Type>();
if (!TryInstantiateDelegateParameters(delegateType, argsCount, generics, out parameters))
var generics = (context != null) ? context.Generics : new Dictionary<Type, Type>();
if (!TryInstantiateDelegateParameters(delegateParameters.Select(p => p.ParameterType).ToArray(), argsCount, generics, out parameters))
return false;

return true;
}

private bool TryInstantiateDelegateParameters(Type generic, int argsCount, Dictionary<string, Type> generics, [NotNullWhen(true)] out Type[]? instantiation)
private bool TryInstantiateDelegateParameters(Type[] delegateParameters, int argsCount, Dictionary<Type, Type> generics, [NotNullWhen(true)] out Type[]? instantiation)
{
var genericArgs = generic.GetGenericArguments();
var substitutions = new Type[argsCount];

for (var argIndex = 0; argIndex < argsCount; argIndex++)
{
var currentArg = genericArgs[argIndex];
var currentArg = delegateParameters[argIndex];
var assignedArg = ReflectionUtils.AssignGenericParameters(currentArg, generics);

if (!currentArg.IsGenericParameter)
{
// This is a known type
substitutions[argIndex] = currentArg;
}
else if (currentArg.IsGenericParameter && generics.ContainsKey(currentArg.Name))
{
// This is a generic parameter
// But we already inferred its type
substitutions[argIndex] = generics[currentArg.Name];
}
else
if (assignedArg.ContainsGenericParameters)
{
// This is an unknown type
instantiation = null;
return false;
}
else
{
substitutions[argIndex] = assignedArg;
}
}

instantiation = substitutions;
Expand Down
12 changes: 6 additions & 6 deletions src/Framework/Framework/Compilation/Inference/TypeInferer.cs
Expand Up @@ -74,11 +74,11 @@ private void RefineCandidates(int index)
return;

var newCandidates = new List<MethodInfo>();
var newInstantiations = new Dictionary<string, HashSet<Type>>();
var newInstantiations = new Dictionary<Type, HashSet<Type>>();

// Check if we can remove some candidates
// Also try to infer generics based on provided argument
var tempInstantiations = new Dictionary<string, Type>();
var tempInstantiations = new Dictionary<Type, Type>();
foreach (var candidate in context.Target.Candidates!.Where(c => c.GetParameters().Length > index))
{
tempInstantiations.Clear();
Expand All @@ -87,12 +87,12 @@ private void RefineCandidates(int index)

if (parameterType.IsGenericParameter)
{
tempInstantiations.Add(parameterType.Name, argumentType);
tempInstantiations.Add(parameterType, argumentType);
}
else if (parameterType.ContainsGenericParameters)
{
// Check if we already inferred instantiation for these generics
if (!parameterType.GetGenericArguments().Any(param => !context.Generics.ContainsKey(param.Name)))
if (!parameterType.GetGenericArguments().Any(param => !context.Generics.ContainsKey(param)))
continue;

// Try to infer instantiation based on given argument
Expand All @@ -119,15 +119,15 @@ private void RefineCandidates(int index)
context.Target.Candidates = newCandidates;
}

private bool TryInferInstantiation(Type generic, Type concrete, Dictionary<string, Type> generics)
private bool TryInferInstantiation(Type generic, Type concrete, Dictionary<Type, Type> generics)
{
if (generic == concrete)
return true;

if (generic.IsGenericParameter)
{
// We found the instantiation
generics.Add(generic.Name, concrete);
generics.Add(generic, concrete);
return true;
}
else if (ReflectionUtils.IsEnumerable(generic))
Expand Down
125 changes: 119 additions & 6 deletions src/Framework/Framework/Utils/ReflectionUtils.cs
Expand Up @@ -99,14 +99,17 @@ public static IEnumerable<MethodInfo> GetAllMethods(this Type type, BindingFlags
/// </summary>
public static bool IsAssignableToGenericType(this Type givenType, Type genericType, [NotNullWhen(returnValue: true)] out Type? commonType)
{
var interfaceTypes = givenType.GetInterfaces();

foreach (var it in interfaceTypes)
if (genericType.IsInterface)
{
if (it.IsGenericType && it.GetGenericTypeDefinition() == genericType)
var interfaceTypes = givenType.GetInterfaces();

foreach (var it in interfaceTypes)
{
commonType = it;
return true;
if (it.IsGenericType && it.GetGenericTypeDefinition() == genericType)
{
commonType = it;
return true;
}
}
}

Expand Down Expand Up @@ -665,5 +668,115 @@ public static IEnumerable<Type> GetBaseTypesAndInterfaces(Type type)
type = baseType;
}
}


internal static bool TryUnifyGenericTypes(Type a, Type b, Dictionary<Type, Type> genericAssignment)
{
if (a == b)
return true;

if (a.IsGenericParameter)
{
if (genericAssignment.ContainsKey(a))
return TryUnifyGenericTypes(genericAssignment[a], b, genericAssignment);

genericAssignment.Add(a, b);
return true;
}
else if (b.IsGenericParameter)
{
if (genericAssignment.ContainsKey(b))
return TryUnifyGenericTypes(a, genericAssignment[b], genericAssignment);

genericAssignment.Add(b, a);
return true;
}
else if (a.IsGenericType && b.IsGenericType)
{
if (a.GetGenericTypeDefinition() != b.GetGenericTypeDefinition())
return false;

var aArgs = a.GetGenericArguments();
var bArgs = b.GetGenericArguments();
if (aArgs.Length != bArgs.Length)
return false;

for (var i = 0; i < aArgs.Length; i++)
{
if (!TryUnifyGenericTypes(aArgs[i], bArgs[i], genericAssignment))
return false;
}

return true;
}
else
{
return false;
}
}

internal static void ExpandUnifiedTypes(Dictionary<Type, Type> genericAssignment)
{
var iteration = 0;
bool dirty;
do
{
dirty = false;
iteration++;
if (iteration > 100)
throw new Exception("Too much recursion in ExpandUnifiedTypes");

foreach (var (key, value) in genericAssignment.ToArray())
{
var expanded = AssignGenericParameters(value, genericAssignment);
if (expanded != value)
{
genericAssignment[key] = expanded;
dirty = true;
}
}
}
while (dirty);
}

internal static Type AssignGenericParameters(Type t, IReadOnlyDictionary<Type, Type> genericAssignment)
{
if (!t.ContainsGenericParameters)
return t;

if (t.IsGenericParameter)
{
if (genericAssignment.TryGetValue(t, out var result))
return result;
else
return t;
}
else if (t.IsGenericType)
{
var args = t.GetGenericArguments();
for (var i = 0; i < args.Length; i++)
{
args[i] = AssignGenericParameters(args[i], genericAssignment);
}
if (args.SequenceEqual(t.GetGenericArguments()))
return t;
else
return t.GetGenericTypeDefinition().MakeGenericType(args);
}
else if (t.HasElementType)
{
var el = AssignGenericParameters(t.GetElementType()!, genericAssignment);
if (el == t.GetElementType())
return t;
else if (t.IsArray)
return el.MakeArrayType(t.GetArrayRank());
else
throw new NotSupportedException();
}
else
{
return t;
}
}
}
}

0 comments on commit b4aab99

Please sign in to comment.