Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move nargs/isva to CodeInfo #54341

Merged
merged 1 commit into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 18 additions & 27 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,12 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
if isa(rt, InterConditional) && rt.slot == i
return rt
else
thentype = elsetype = tmeet(𝕃ᵢ, widenslotwrapper(argtypes[i]), fieldtype(sig, i))
argt = widenslotwrapper(argtypes[i])
if isvarargtype(argt)
@assert fieldcount(sig) == i
argt = unwrapva(argt)
end
thentype = elsetype = tmeet(𝕃ᵢ, argt, fieldtype(sig, i))
Comment on lines +482 to +487
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check wasn't needed previously, but has it become necessary with this PR? argtypes used by abstract_apply can certainly include Vararg, but I find it curious why the compiler functioned well without this check before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect it was theoretically reachable, but the change to apply this to the pre-processed signature made it more likely. I didn't bother investigating.

condval = maybe_extract_const_bool(rt)
condval === true && (elsetype = Bottom)
condval === false && (thentype = Bottom)
Expand Down Expand Up @@ -986,15 +991,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
# N.B. remarks are emitted within `const_prop_entry_heuristic`
return nothing
end
nargs::Int = method.nargs
method.isva && (nargs -= 1)
length(arginfo.argtypes) < nargs && return nothing
if !const_prop_argument_heuristic(interp, arginfo, sv)
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
return nothing
end
all_overridden = is_all_overridden(interp, arginfo, sv)
if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv)
if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv)
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
return nothing
end
Expand Down Expand Up @@ -1113,9 +1115,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
end

function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, nargs::Int, all_overridden::Bool, sv::AbsIntState)
arginfo::ArgInfo, all_overridden::Bool, sv::AbsIntState)
argtypes = arginfo.argtypes
if nargs > 1
if length(argtypes) > 1
𝕃ᵢ = typeinf_lattice(interp)
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
arrty = argtypes[2]
Expand Down Expand Up @@ -1274,7 +1276,7 @@ function const_prop_call(interp::AbstractInterpreter,
end
overridden_by_const = falses(length(argtypes))
for i = 1:length(argtypes)
if argtypes[i] !== cache_argtypes[i]
if argtypes[i] !== argtype_by_index(cache_argtypes, i)
overridden_by_const[i] = true
end
end
Expand Down Expand Up @@ -1349,20 +1351,6 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
end
given_argtypes[i] = widenslotwrapper(argtype)
end
if condargs !== nothing
given_argtypes = let condargs=condargs
va_process_argtypes(𝕃, given_argtypes, mi) do isva_given_argtypes::Vector{Any}, last::Int
# invalidate `Conditional` imposed on varargs
for (slotid, i) in condargs
if slotid ≥ last && (1 ≤ i ≤ length(isva_given_argtypes)) # `Conditional` is already widened to vararg-tuple otherwise
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
end
end
end
end
else
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
end
Comment on lines -1352 to -1365
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this change, we might be able to switch the weird interface of [Simple|Widened|Conditional]Argtypes (if you can even call it an 'interface') to something better like AbstractLattice.

return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

Expand Down Expand Up @@ -1721,7 +1709,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
return CallMeta(res, exct, effects, retinfo)
end

function argtype_by_index(argtypes::Vector{Any}, i::Int)
function argtype_by_index(argtypes::Vector{Any}, i::Integer)
n = length(argtypes)
na = argtypes[n]
if isvarargtype(na)
Expand Down Expand Up @@ -2890,12 +2878,12 @@ end
struct BestguessInfo{Interp<:AbstractInterpreter}
interp::Interp
bestguess
nargs::Int
nargs::UInt
slottypes::Vector{Any}
changes::VarTable
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::Int,
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::UInt,
slottypes::Vector{Any}, changes::VarTable) where Interp<:AbstractInterpreter
new{Interp}(interp, bestguess, nargs, slottypes, changes)
new{Interp}(interp, bestguess, Int(nargs), slottypes, changes)
end
end

Expand Down Expand Up @@ -2970,7 +2958,7 @@ end
# pick up the first "interesting" slot, convert `rt` to its `Conditional`
# TODO: ideally we want `Conditional` and `InterConditional` to convey
# constraints on multiple slots
for slot_id = 1:info.nargs
for slot_id = 1:Int(info.nargs)
rt = bool_rt_to_conditional(rt, slot_id, info)
rt isa InterConditional && break
end
Expand All @@ -2981,6 +2969,9 @@ end
⊑ᵢ = ⊑(typeinf_lattice(info.interp))
old = info.slottypes[slot_id]
new = widenslotwrapper(info.changes[slot_id].typ) # avoid nested conditional
if isvarargtype(old) || isvarargtype(new)
return rt
end
if new ⊑ᵢ old && !(old ⊑ᵢ new)
if isa(rt, Const)
val = rt.val
Expand Down
186 changes: 91 additions & 95 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,56 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
for i = 1:length(argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

function pick_const_arg(𝕃::AbstractLattice, @nospecialize(given_argtype), @nospecialize(cache_argtype))
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
# prefer the argtype we were given over the one computed from `mi`
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
!⊏(𝕃, given_argtype, cache_argtype))
# if the type information of this `PartialStruct` is less strict than
# declared method signature, narrow it down using `tmeet`
given_argtype = tmeet(𝕃, given_argtype, cache_argtype)
end
else
given_argtype = cache_argtype
end
return given_argtype
end

function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
nargtypes = length(given_argtypes)
@assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
for i = 1:nargtypes
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
# prefer the argtype we were given over the one computed from `mi`
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
!⊏(𝕃, given_argtype, cache_argtype))
# if the type information of this `PartialStruct` is less strict than
# declared method signature, narrow it down using `tmeet`
given_argtypes[i] = tmeet(𝕃, given_argtype, cache_argtype)
end
if length(given_argtypes) == 0 || length(cache_argtypes) == 0
return Any[]
end
Comment on lines +46 to +48
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this check?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Toplevel inference has no arguments. I think that used to bypass this path.

given_va = given_argtypes[end]
cache_va = cache_argtypes[end]
if isvarargtype(given_va)
ngiven = length(given_argtypes)
va = unwrapva(given_va)
if isvarargtype(cache_va)
# Process the common prefix, then join
nprocessargs = max(length(given_argtypes)-1, length(cache_argtypes)-1)
resize!(given_argtypes, nprocessargs+1)
given_argtypes[end] = Vararg{pick_const_arg(𝕃, unwrapva(given_va), unwrapva(cache_va))}
else
given_argtypes[i] = cache_argtype
nprocessargs = length(cache_argtypes)
resize!(given_argtypes, nprocessargs)
end
for i = ngiven:nprocessargs
given_argtypes[i] = va
end
elseif isvarargtype(cache_va)
nprocessargs = length(given_argtypes)
else
@assert length(given_argtypes) == length(cache_argtypes)
nprocessargs = length(given_argtypes)
end
for i = 1:nprocessargs
given_argtype = given_argtypes[i]
cache_argtype = argtype_by_index(cache_argtypes, i)
given_argtype = pick_const_arg(𝕃, given_argtype, cache_argtype)
given_argtypes[i] = given_argtype
end
return given_argtypes
end
Expand All @@ -60,25 +89,33 @@ function is_argtype_match(𝕃::AbstractLattice,
end
end

va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) =
va_process_argtypes(Returns(nothing), 𝕃, given_argtypes, mi)
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
def = mi.def::Method
isva = def.isva
nargs = Int(def.nargs)
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool)
if isva || (!isempty(given_argtypes) && isvarargtype(given_argtypes[end]))
isva_given_argtypes = Vector{Any}(undef, Int(nargs))
for i = 1:(nargs-isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
newarg = argtype_by_index(given_argtypes, i)
if isva && has_conditional(𝕃) && isa(newarg, Conditional)
if newarg.slot > (nargs-isva)
newarg = widenconditional(newarg)
end
end
isva_given_argtypes[i] = newarg
end
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
if has_conditional(𝕃)
for i = last:length(given_argtypes)
newarg = given_argtypes[i]
if isa(newarg, Conditional) && newarg.slot > (nargs-isva)
given_argtypes[i] = widenconditional(newarg)
end
end
end
end
isva_given_argtypes[nargs] = tuple_tfunc(𝕃, given_argtypes[last:end])
va_handler!(isva_given_argtypes, last)
end
return isva_given_argtypes
end
Expand All @@ -87,84 +124,44 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi
end

function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
toplevel = method === nothing
isva = !toplevel && method.isva
mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
nargs::Int = toplevel ? 0 : method.nargs
cache_argtypes = Vector{Any}(undef, nargs)
# First, if we're dealing with a varargs method, then we set the last element of `args`
# to the appropriate `Tuple` type or `PartialStruct` instance.
mi_argtypes_length = length(mi_argtypes)
if !toplevel && isva
if specTypes::Type == Tuple
mi_argtypes = Any[Any for i = 1:nargs]
if nargs > 1
mi_argtypes[end] = Tuple
end
vargtype = Tuple
else
if nargs > mi_argtypes_length
va = mi_argtypes[mi_argtypes_length]
if isvarargtype(va)
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
vargtype = Tuple{new_va}
else
vargtype = Tuple{}
end
else
vargtype_elements = Any[]
for i in nargs:mi_argtypes_length
p = mi_argtypes[i]
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
end
for i in 1:length(vargtype_elements)
atyp = vargtype_elements[i]
if issingletontype(atyp)
# replace singleton types with their equivalent Const object
vargtype_elements[i] = Const(atyp.instance)
elseif isconstType(atyp)
vargtype_elements[i] = Const(atyp.parameters[1])
end
end
vargtype = tuple_tfunc(fallback_lattice, vargtype_elements)
end
end
cache_argtypes[nargs] = vargtype
nargs -= 1
nargtypes = length(mi_argtypes)
nargs = isa(method, Method) ? method.nargs : 0
if length(mi_argtypes) < nargs && isvarargtype(mi_argtypes[end])
resize!(mi_argtypes, nargs)
end
# Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
# type info as we go (where possible). Note that if we're dealing with a varargs method,
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
# we don't overwrite the result of that work here).
if mi_argtypes_length > 0
tail_index = nargtypes = min(mi_argtypes_length, nargs)
local lastatype
for i = 1:nargtypes
atyp = mi_argtypes[i]
if i == nargtypes && isvarargtype(atyp)
atyp = unwrapva(atyp)
tail_index -= 1
end
atyp = unwraptv(atyp)
if issingletontype(atyp)
# replace singleton types with their equivalent Const object
atyp = Const(atyp.instance)
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
i == nargtypes && (lastatype = atyp)
cache_argtypes[i] = atyp
tail_index = min(nargtypes, nargs)
local lastatype
for i = 1:nargtypes
atyp = mi_argtypes[i]
wasva = false
if i == nargtypes && isvarargtype(atyp)
wasva = true
atyp = unwrapva(atyp)
end
for i = (tail_index+1):nargs
cache_argtypes[i] = lastatype
atyp = unwraptv(atyp)
if issingletontype(atyp)
# replace singleton types with their equivalent Const object
atyp = Const(atyp.instance)
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
else
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
mi_argtypes[i] = atyp
if wasva
lastatype = atyp
mi_argtypes[end] = Vararg{widenconst(atyp)}
end
end
for i = (tail_index+1):(nargs-1)
mi_argtypes[i] = lastatype
end
return cache_argtypes
return mi_argtypes
end

# eliminate free `TypeVar`s in order to make the life much easier down the road:
Expand All @@ -184,7 +181,6 @@ function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes:
cache::Vector{InferenceResult})
method = mi.def::Method
nargtypes = length(given_argtypes)
@assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`"
for cached_result in cache
cached_result.linfo === mi || @goto next_cache
cache_argtypes = cached_result.argtypes
Expand Down
10 changes: 6 additions & 4 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ mutable struct InferenceState
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
argtypes = result.argtypes

argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva)

nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
Expand Down Expand Up @@ -766,10 +769,9 @@ function print_callstack(sv::InferenceState)
end

function narguments(sv::InferenceState, include_va::Bool=true)
def = sv.linfo.def
nargs = length(sv.result.argtypes)
nargs = sv.src.nargs
if !include_va
nargs -= isa(def, Method) && def.isva
nargs -= sv.src.isva
end
return nargs
end
Expand Down Expand Up @@ -831,7 +833,7 @@ function IRInterpretationState(interp::AbstractInterpreter,
end
method_info = MethodInfo(src)
ir = inflate_ir(src, mi)
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi)
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva)
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
codeinst.min_world, codeinst.max_world)
end
Expand Down