using System;
using System.Collections.Generic;
using dnlib.DotNet;
namespace dnSpy.Contracts.Decompiler {
/// <summary>
/// Substitutes generic parameters in signatures: each <see cref="GenericVar"/> (<c>!n</c>) is replaced with
/// entry <c>n</c> of the type generic arguments and each <see cref="GenericMVar"/> (<c>!!n</c>) with entry
/// <c>n</c> of the method generic arguments.
/// </summary>
public struct GenericArgumentResolver {
	IList<TypeSig> typeGenArgs;
	IList<TypeSig> methodGenArgs;
	// Bounds how deeply nested signatures are walked (see RecursionCounter).
	RecursionCounter recursionCounter;

	GenericArgumentResolver(IList<TypeSig>? typeGenArgs, IList<TypeSig>? methodGenArgs) {
		// Normalize null to an empty list so index lookups never need a null check.
		this.typeGenArgs = typeGenArgs ?? Array.Empty<TypeSig>();
		this.methodGenArgs = methodGenArgs ?? Array.Empty<TypeSig>();
		recursionCounter = new RecursionCounter();
	}

	/// <summary>
	/// Resolves the type signature with the specified generic arguments.
	/// </summary>
	/// <param name="typeSig">The type signature.</param>
	/// <param name="typeGenArgs">The type generic arguments, indexed by <see cref="GenericVar"/> number.</param>
	/// <param name="methodGenArgs">The method generic arguments, indexed by <see cref="GenericMVar"/> number.</param>
	/// <returns>
	/// The resolved type signature. <paramref name="typeSig"/> itself is returned when it is <c>null</c>
	/// or when both generic argument lists are <c>null</c> or empty; no exception is thrown in that case.
	/// Generic parameters whose number is out of range are left unresolved, and sub-signatures nested
	/// deeper than the internal recursion limit resolve to <c>null</c>.
	/// </returns>
	public static TypeSig? Resolve(TypeSig? typeSig, IList<TypeSig>? typeGenArgs, IList<TypeSig>? methodGenArgs) {
		if (typeSig is null)
			return typeSig;
		if ((typeGenArgs is null || typeGenArgs.Count == 0) && (methodGenArgs is null || methodGenArgs.Count == 0))
			return typeSig;
		var resolver = new GenericArgumentResolver(typeGenArgs, methodGenArgs);
		return resolver.ResolveGenericArgs(typeSig);
	}

	/// <summary>
	/// Resolves the method signature with the specified generic arguments.
	/// </summary>
	/// <param name="methodSig">The method signature.</param>
	/// <param name="typeGenArgs">The type generic arguments, indexed by <see cref="GenericVar"/> number.</param>
	/// <param name="methodGenArgs">The method generic arguments, indexed by <see cref="GenericMVar"/> number.</param>
	/// <returns>
	/// <c>null</c> if <paramref name="methodSig"/> is <c>null</c>; <paramref name="methodSig"/> itself when both
	/// generic argument lists are <c>null</c> or empty; otherwise a new <see cref="MethodSig"/> (even when the
	/// input is a <see cref="PropertySig"/>) carrying the input's calling convention and resolved types.
	/// </returns>
	public static MethodBaseSig? Resolve(MethodBaseSig? methodSig, IList<TypeSig>? typeGenArgs, IList<TypeSig>? methodGenArgs) {
		if (methodSig is null)
			return null;
		if ((typeGenArgs is null || typeGenArgs.Count == 0) && (methodGenArgs is null || methodGenArgs.Count == 0))
			return methodSig;
		var resolver = new GenericArgumentResolver(typeGenArgs, methodGenArgs);
		return resolver.ResolveGenericArgs(methodSig);
	}

	// If typeSig is a !!n / !n generic parameter with an in-range index (and a non-null argument),
	// replaces it with that argument and returns true; otherwise leaves it untouched and returns false.
	bool ReplaceGenericArg(ref TypeSig typeSig) {
		if (typeSig is GenericMVar genericMVar) {
			var newSig = Read(methodGenArgs, genericMVar.Number);
			if (newSig is not null) {
				typeSig = newSig;
				return true;
			}
			return false;
		}
		if (typeSig is GenericVar genericVar) {
			var newSig = Read(typeGenArgs, genericVar.Number);
			if (newSig is not null) {
				typeSig = newSig;
				return true;
			}
			return false;
		}
		return false;
	}

	// Bounds-checked list lookup; returns null when index is out of range.
	static TypeSig? Read(IList<TypeSig> sigs, uint index) {
		if (index < (uint)sigs.Count)
			return sigs[(int)index];
		return null;
	}

	// Entry point for the public MethodBaseSig overload. The output is always a MethodSig that
	// keeps the input's calling convention.
	MethodSig? ResolveGenericArgs(MethodBaseSig sig) {
		if (sig is null)
			return null;
		if (!recursionCounter.Increment())
			return null;
		MethodSig result = ResolveGenericArgs(new MethodSig(sig.CallingConvention), sig);
		recursionCounter.Decrement();
		return result;
	}

	// Fills 'sig' with the resolved contents of 'old' and returns 'sig'.
	// Delegates to ResolveGenericArgs2 so params after the vararg sentinel are copied too
	// (checking the freshly created 'sig' for a sentinel list would always see null).
	MethodSig ResolveGenericArgs(MethodSig sig, MethodBaseSig old) {
		ResolveGenericArgs2<MethodBaseSig>(sig, old);
		return sig;
	}

	// Returns a resolved copy of typeSig. Composite signatures (pointers, arrays, modifiers,
	// generic instances, ...) are rebuilt with resolved element types; leaf signatures are
	// returned as-is.
	TypeSig? ResolveGenericArgs(TypeSig typeSig) {
		if (typeSig is null)
			return null;
		if (!recursionCounter.Increment())
			return null;

		if (ReplaceGenericArg(ref typeSig)) {
			recursionCounter.Decrement();
			return typeSig;
		}

		TypeSig result;
		switch (typeSig.ElementType) {
		case ElementType.Ptr:
			result = new PtrSig(ResolveGenericArgs(typeSig.Next));
			break;
		case ElementType.ByRef:
			result = new ByRefSig(ResolveGenericArgs(typeSig.Next));
			break;
		case ElementType.Var:
			// Only reached when the index was out of range: keep an unresolved copy.
			result = new GenericVar(((GenericVar)typeSig).Number, ((GenericVar)typeSig).OwnerType);
			break;
		case ElementType.ValueArray:
			result = new ValueArraySig(ResolveGenericArgs(typeSig.Next), ((ValueArraySig)typeSig).Size);
			break;
		case ElementType.SZArray:
			result = new SZArraySig(ResolveGenericArgs(typeSig.Next));
			break;
		case ElementType.MVar:
			// Only reached when the index was out of range: keep an unresolved copy.
			result = new GenericMVar(((GenericMVar)typeSig).Number, ((GenericMVar)typeSig).OwnerMethod);
			break;
		case ElementType.CModReqd:
			result = new CModReqdSig(((ModifierSig)typeSig).Modifier, ResolveGenericArgs(typeSig.Next));
			break;
		case ElementType.CModOpt:
			result = new CModOptSig(((ModifierSig)typeSig).Modifier, ResolveGenericArgs(typeSig.Next));
			break;
		case ElementType.Module:
			result = new ModuleSig(((ModuleSig)typeSig).Index, ResolveGenericArgs(typeSig.Next));
			break;
		case ElementType.Pinned:
			result = new PinnedSig(ResolveGenericArgs(typeSig.Next));
			break;
		case ElementType.FnPtr:
			result = new FnPtrSig(ResolveGenericArgs(((FnPtrSig)typeSig).MethodSig));
			break;
		case ElementType.Array:
			var arraySig = (ArraySig)typeSig;
			var sizes = new List<uint>(arraySig.Sizes);
			var lbounds = new List<int>(arraySig.LowerBounds);
			result = new ArraySig(ResolveGenericArgs(typeSig.Next), arraySig.Rank, sizes, lbounds);
			break;
		case ElementType.GenericInst:
			var gis = (GenericInstSig)typeSig;
			var genArgs = new List<TypeSig>(gis.GenericArguments.Count);
			foreach (var ga in gis.GenericArguments)
				genArgs.Add(ResolveGenericArgs(ga));
			result = new GenericInstSig(ResolveGenericArgs(gis.GenericType) as ClassOrValueTypeSig, genArgs);
			break;
		default:
			result = typeSig;
			break;
		}

		recursionCounter.Decrement();
		return result;
	}

	// Dispatches on the concrete signature kind; unknown kinds (and null) resolve to null.
	CallingConventionSig? ResolveGenericArgs(CallingConventionSig sig) {
		if (!recursionCounter.Increment())
			return null;
		CallingConventionSig? result = sig switch {
			MethodSig msig => ResolveGenericArgs(msig),
			FieldSig fsig => ResolveGenericArgs(fsig),
			LocalSig lsig => ResolveGenericArgs(lsig),
			PropertySig psig => ResolveGenericArgs(psig),
			GenericInstMethodSig gsig => ResolveGenericArgs(gsig),
			_ => null,
		};
		recursionCounter.Decrement();
		return result;
	}

	// Copies the calling convention (HasThis, Generic, VarArg, ...) so the resolved
	// signature describes the same kind of method as the input.
	MethodSig ResolveGenericArgs(MethodSig sig) {
		var msig = ResolveGenericArgs2(new MethodSig(sig.CallingConvention), sig);
		msig.OriginalToken = sig.OriginalToken;
		return msig;
	}

	// Preserves whether the property signature is an instance (HasThis) signature.
	PropertySig ResolveGenericArgs(PropertySig sig) => ResolveGenericArgs2(new PropertySig(sig.HasThis), sig);

	// Fills outSig with the resolved return type, generic parameter count, parameters and
	// (if present on inSig) vararg params after the sentinel. Returns outSig.
	T ResolveGenericArgs2<T>(T outSig, T inSig) where T : MethodBaseSig {
		outSig.RetType = ResolveGenericArgs(inSig.RetType);
		outSig.GenParamCount = inSig.GenParamCount;
		UpdateSigList(outSig.Params, inSig.Params);
		if (inSig.ParamsAfterSentinel is not null) {
			outSig.ParamsAfterSentinel = new List<TypeSig>(inSig.ParamsAfterSentinel.Count);
			UpdateSigList(outSig.ParamsAfterSentinel, inSig.ParamsAfterSentinel);
		}
		return outSig;
	}

	// Appends the resolved form of every element of 'source' to 'destination'.
	void UpdateSigList(IList<TypeSig> destination, IList<TypeSig> source) {
		foreach (var arg in source)
			destination.Add(ResolveGenericArgs(arg));
	}

	FieldSig ResolveGenericArgs(FieldSig sig) => new FieldSig(ResolveGenericArgs(sig.Type));

	LocalSig ResolveGenericArgs(LocalSig sig) {
		var lsig = new LocalSig();
		UpdateSigList(lsig.Locals, sig.Locals);
		return lsig;
	}

	GenericInstMethodSig ResolveGenericArgs(GenericInstMethodSig sig) {
		var gsig = new GenericInstMethodSig();
		UpdateSigList(gsig.GenericArguments, sig.GenericArguments);
		return gsig;
	}
}
/// <summary>
/// Tracks nesting depth during recursive signature walks and refuses to go deeper than a fixed limit.
/// </summary>
struct RecursionCounter {
	// Maximum number of nested levels that may be entered at once.
	const int MaxRecursionDepth = 100;

	// Number of levels currently entered.
	int depth;

	/// <summary>
	/// Enters one more nesting level.
	/// </summary>
	/// <returns>
	/// <c>true</c> if the level was entered; <c>false</c> (depth unchanged) if the limit is already reached.
	/// </returns>
	public bool Increment() {
		bool canEnter = depth < MaxRecursionDepth;
		if (canEnter)
			depth++;
		return canEnter;
	}

	/// <summary>
	/// Leaves one nesting level. DEBUG builds throw if no level is currently entered.
	/// </summary>
	public void Decrement() {
#if DEBUG
		if (!(depth > 0))
			throw new InvalidOperationException("recursionCounter <= 0");
#endif
		depth--;
	}
}
}