Skip to content

Commit 83bd183

Browse files
committed
__resultRef implementation
1 parent 2a09e7a commit 83bd183

File tree

5 files changed

+208
-14
lines changed

5 files changed

+208
-14
lines changed

Harmony/Documentation/articles/patching-injections.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ Patches can use an argument called **`__instance`** to access the instance value
1414

1515
Patches can use an argument called **`__result`** to access the returned value. The type must match the return type of the original or be assignable from it. For prefixes, as the original method hasn't run yet, the value of `__result` is the default for that type. For most reference types, that would be `null`. If you wish to **alter** the `__result`, you need to define it **by reference** like `ref string name`.
1616

17+
### __resultRef
18+
19+
Patches can use an argument called **`__resultRef`** to alter the "**ref return**" reference itself. The type must be `RefResult<T>` by reference, where `T` must match the return type of the original, without `ref` modifier. For example `ref RefResult<string> __resultRef`.
20+
1721
### __state
1822

1923
Patches can use an argument called **`__state`** to store information in the prefix method that can be accessed again in the postfix method. Think of it as a local variable. It can be any type and you are responsible to initialize its value in the prefix. **Note:** It only works if both patches are defined in the same class.

Harmony/Extras/RefResult.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
namespace HarmonyLib;
2+
3+
/// <summary>Delegate type for "ref return" injections</summary>
4+
/// <typeparam name="T">Return type of the original method, without ref modifier</typeparam>
5+
public delegate ref T RefResult<T>();

Harmony/Internal/MethodPatcher.cs

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ internal class MethodPatcher
1515
const string ORIGINAL_METHOD_PARAM = "__originalMethod";
1616
const string ARGS_ARRAY_VAR = "__args";
1717
const string RESULT_VAR = "__result";
18+
const string RESULT_REF_VAR = "__resultRef";
1819
const string STATE_VAR = "__state";
1920
const string EXCEPTION_VAR = "__exception";
2021
const string RUN_ORIGINAL_VAR = "__runOriginal";
@@ -76,6 +77,19 @@ internal MethodInfo CreateReplacement(out Dictionary<int, CodeInstruction> final
7677
privateVars[RESULT_VAR] = resultVariable;
7778
}
7879

80+
if (fixes.Any(fix => fix.GetParameters().Any(p => p.Name == RESULT_REF_VAR)))
81+
{
82+
if(returnType.IsByRef)
83+
{
84+
var resultRefVariable = il.DeclareLocal(
85+
typeof(RefResult<>).MakeGenericType(returnType.GetElementType())
86+
);
87+
emitter.Emit(OpCodes.Ldnull);
88+
emitter.Emit(OpCodes.Stloc, resultRefVariable);
89+
privateVars[RESULT_REF_VAR] = resultRefVariable;
90+
}
91+
}
92+
7993
LocalBuilder argsArrayVariable = null;
8094
if (fixes.Any(fix => fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR)))
8195
{
@@ -432,10 +446,11 @@ bool EmitOriginalBaseMethod()
432446
return true;
433447
}
434448

435-
void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variables, LocalBuilder runOriginalVariable, bool allowFirsParamPassthrough, out LocalBuilder tmpInstanceBoxingVar, out LocalBuilder tmpObjectVar, List<KeyValuePair<LocalBuilder, Type>> tmpBoxVars)
449+
void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variables, LocalBuilder runOriginalVariable, bool allowFirsParamPassthrough, out LocalBuilder tmpInstanceBoxingVar, out LocalBuilder tmpObjectVar, out bool refResultUsed, List<KeyValuePair<LocalBuilder, Type>> tmpBoxVars)
436450
{
437451
tmpInstanceBoxingVar = null;
438452
tmpObjectVar = null;
453+
refResultUsed = false;
439454

440455
var isInstance = original.IsStatic is false;
441456
var originalParameters = original.GetParameters();
@@ -474,10 +489,10 @@ void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variab
474489
else
475490
{
476491
var paramType = patchParam.ParameterType;
477-
492+
478493
var parameterIsRef = paramType.IsByRef;
479494
var parameterIsObject = paramType == typeof(object) || paramType == typeof(object).MakeByRefType();
480-
495+
481496
if (AccessTools.IsStruct(originalType))
482497
{
483498
if (parameterIsObject)
@@ -571,7 +586,6 @@ void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variab
571586
// treat __result var special
572587
if (patchParam.Name == RESULT_VAR)
573588
{
574-
var returnType = AccessTools.GetReturnedType(original);
575589
if (returnType == typeof(void))
576590
throw new Exception($"Cannot get result from void method {original.FullDescription()}");
577591
var resultType = patchParam.ParameterType;
@@ -597,6 +611,25 @@ void EmitCallParameter(MethodInfo patch, Dictionary<string, LocalBuilder> variab
597611
continue;
598612
}
599613

614+
// treat __resultRef delegate special
615+
if (patchParam.Name == RESULT_REF_VAR)
616+
{
617+
if (!returnType.IsByRef)
618+
throw new Exception(
619+
$"Cannot use {RESULT_REF_VAR} with non-ref return type {returnType.FullName} of method {original.FullDescription()}");
620+
621+
var resultType = patchParam.ParameterType;
622+
var expectedTypeRef = typeof(RefResult<>).MakeGenericType(returnType.GetElementType()).MakeByRefType();
623+
if (resultType != expectedTypeRef)
624+
throw new Exception(
625+
$"Wrong type of {RESULT_REF_VAR} for method {original.FullDescription()}. Expected {expectedTypeRef.FullName}, got {resultType.FullName}");
626+
627+
emitter.Emit(OpCodes.Ldloca, variables[RESULT_REF_VAR]);
628+
629+
refResultUsed = true;
630+
continue;
631+
}
632+
600633
// any other declared variables
601634
if (variables.TryGetValue(patchParam.Name, out var localBuilder))
602635
{
@@ -763,7 +796,7 @@ void AddPrefixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOri
763796
}
764797

765798
var tmpBoxVars = new List<KeyValuePair<LocalBuilder, Type>>();
766-
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars);
799+
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars);
767800
emitter.Emit(OpCodes.Call, fix);
768801
if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR))
769802
RestoreArgumentArray(variables);
@@ -774,7 +807,22 @@ void AddPrefixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOri
774807
emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType);
775808
emitter.Emit(OpCodes.Stobj, original.DeclaringType);
776809
}
777-
if (tmpObjectVar != null)
810+
if (refResultUsed)
811+
{
812+
var label = il.DefineLabel();
813+
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
814+
emitter.Emit(OpCodes.Brfalse_S, label);
815+
816+
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
817+
emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke"));
818+
emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]);
819+
emitter.Emit(OpCodes.Ldnull);
820+
emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]);
821+
822+
emitter.MarkLabel(label);
823+
emitter.Emit(OpCodes.Nop);
824+
}
825+
else if (tmpObjectVar != null)
778826
{
779827
emitter.Emit(OpCodes.Ldloc, tmpObjectVar);
780828
emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original));
@@ -815,7 +863,7 @@ bool AddPostfixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOr
815863
// throw new Exception("Methods without body cannot have postfixes. Use a transpiler instead.");
816864

817865
var tmpBoxVars = new List<KeyValuePair<LocalBuilder, Type>>();
818-
EmitCallParameter(fix, variables, runOriginalVariable, true, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars);
866+
EmitCallParameter(fix, variables, runOriginalVariable, true, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars);
819867
emitter.Emit(OpCodes.Call, fix);
820868
if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR))
821869
RestoreArgumentArray(variables);
@@ -826,7 +874,22 @@ bool AddPostfixes(Dictionary<string, LocalBuilder> variables, LocalBuilder runOr
826874
emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType);
827875
emitter.Emit(OpCodes.Stobj, original.DeclaringType);
828876
}
829-
if (tmpObjectVar != null)
877+
if (refResultUsed)
878+
{
879+
var label = il.DefineLabel();
880+
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
881+
emitter.Emit(OpCodes.Brfalse_S, label);
882+
883+
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
884+
emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke"));
885+
emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]);
886+
emitter.Emit(OpCodes.Ldnull);
887+
emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]);
888+
889+
emitter.MarkLabel(label);
890+
emitter.Emit(OpCodes.Nop);
891+
}
892+
else if (tmpObjectVar != null)
830893
{
831894
emitter.Emit(OpCodes.Ldloc, tmpObjectVar);
832895
emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original));
@@ -871,7 +934,7 @@ bool AddFinalizers(Dictionary<string, LocalBuilder> variables, LocalBuilder runO
871934
emitter.MarkBlockBefore(new ExceptionBlock(ExceptionBlockType.BeginExceptionBlock), out var label);
872935

873936
var tmpBoxVars = new List<KeyValuePair<LocalBuilder, Type>>();
874-
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, tmpBoxVars);
937+
EmitCallParameter(fix, variables, runOriginalVariable, false, out var tmpInstanceBoxingVar, out var tmpObjectVar, out var refResultUsed, tmpBoxVars);
875938
emitter.Emit(OpCodes.Call, fix);
876939
if (fix.GetParameters().Any(p => p.Name == ARGS_ARRAY_VAR))
877940
RestoreArgumentArray(variables);
@@ -882,7 +945,22 @@ bool AddFinalizers(Dictionary<string, LocalBuilder> variables, LocalBuilder runO
882945
emitter.Emit(OpCodes.Unbox_Any, original.DeclaringType);
883946
emitter.Emit(OpCodes.Stobj, original.DeclaringType);
884947
}
885-
if (tmpObjectVar != null)
948+
if (refResultUsed)
949+
{
950+
var label = il.DefineLabel();
951+
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
952+
emitter.Emit(OpCodes.Brfalse_S, label);
953+
954+
emitter.Emit(OpCodes.Ldloc, variables[RESULT_REF_VAR]);
955+
emitter.Emit(OpCodes.Callvirt, AccessTools.Method(variables[RESULT_REF_VAR].LocalType, "Invoke"));
956+
emitter.Emit(OpCodes.Stloc, variables[RESULT_VAR]);
957+
emitter.Emit(OpCodes.Ldnull);
958+
emitter.Emit(OpCodes.Stloc, variables[RESULT_REF_VAR]);
959+
960+
emitter.MarkLabel(label);
961+
emitter.Emit(OpCodes.Nop);
962+
}
963+
else if (tmpObjectVar != null)
886964
{
887965
emitter.Emit(OpCodes.Ldloc, tmpObjectVar);
888966
emitter.Emit(OpCodes.Unbox_Any, AccessTools.GetReturnedType(original));

HarmonyTests/Patching/Assets/Specials.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,70 @@ public static void ResetTest()
2525

2626
// -----------------------------------------------------
2727

28+
public class ResultRefStruct
29+
{
30+
// ReSharper disable FieldCanBeMadeReadOnly.Global
31+
public static int[] numbersPrefix = [0, 0];
32+
public static int[] numbersPostfix = [0, 0];
33+
public static int[] numbersPostfixWithNull = [0];
34+
public static int[] numbersFinalizer = [0];
35+
public static int[] numbersMixed = [0, 0];
36+
// ReSharper restore FieldCanBeMadeReadOnly.Global
37+
38+
[MethodImpl(MethodImplOptions.NoInlining)]
39+
public ref int ToPrefix() => ref numbersPrefix[0];
40+
41+
[MethodImpl(MethodImplOptions.NoInlining)]
42+
public ref int ToPostfix() => ref numbersPostfix[0];
43+
44+
[MethodImpl(MethodImplOptions.NoInlining)]
45+
public ref int ToPostfixWithNull() => ref numbersPostfixWithNull[0];
46+
47+
[MethodImpl(MethodImplOptions.NoInlining)]
48+
public ref int ToFinalizer() => throw new Exception();
49+
50+
[MethodImpl(MethodImplOptions.NoInlining)]
51+
public ref int ToMixed() => ref numbersMixed[0];
52+
}
53+
54+
[HarmonyPatch(typeof(ResultRefStruct))]
55+
public class ResultRefStruct_Patch
56+
{
57+
[HarmonyPatch(nameof(ResultRefStruct.ToPrefix))]
58+
[HarmonyPrefix]
59+
public static bool Prefix(ref RefResult<int> __resultRef)
60+
{
61+
__resultRef = () => ref ResultRefStruct.numbersPrefix[1];
62+
return false;
63+
}
64+
65+
[HarmonyPatch(nameof(ResultRefStruct.ToPostfix))]
66+
[HarmonyPostfix]
67+
public static void Postfix(ref RefResult<int> __resultRef) => __resultRef = () => ref ResultRefStruct.numbersPostfix[1];
68+
69+
[HarmonyPatch(nameof(ResultRefStruct.ToPostfixWithNull))]
70+
[HarmonyPostfix]
71+
public static void PostfixWithNull(ref RefResult<int> __resultRef) => __resultRef = null;
72+
73+
[HarmonyPatch(nameof(ResultRefStruct.ToFinalizer))]
74+
[HarmonyFinalizer]
75+
public static Exception Finalizer(ref RefResult<int> __resultRef)
76+
{
77+
__resultRef = () => ref ResultRefStruct.numbersFinalizer[0];
78+
return null;
79+
}
80+
81+
[HarmonyPatch(nameof(ResultRefStruct.ToMixed))]
82+
[HarmonyPostfix]
83+
public static void PostfixMixed(ref int __result, ref RefResult<int> __resultRef)
84+
{
85+
__result = 42;
86+
__resultRef = () => ref ResultRefStruct.numbersMixed[1];
87+
}
88+
}
89+
90+
// -----------------------------------------------------
91+
2892
public class DeadEndCode
2993
{
3094
[MethodImpl(MethodImplOptions.NoInlining)]

HarmonyTests/Patching/Specials.cs

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,49 @@ public void Test_HttpWebRequestGetResponse()
5050
Assert.True(HttpWebRequestPatches.postfixCalled, "Postfix not called");
5151
}
5252

53+
[Test]
54+
public void Test_PatchResultRef()
55+
{
56+
ResultRefStruct.numbersPrefix = [0, 0];
57+
ResultRefStruct.numbersPostfix = [0, 0];
58+
ResultRefStruct.numbersPostfixWithNull = [0];
59+
ResultRefStruct.numbersFinalizer = [0];
60+
ResultRefStruct.numbersMixed = [0, 0];
61+
62+
var test = new ResultRefStruct();
63+
64+
var instance = new Harmony("result-ref-test");
65+
Assert.NotNull(instance);
66+
var processor = instance.CreateClassProcessor(typeof(ResultRefStruct_Patch));
67+
Assert.NotNull(processor, "processor");
68+
69+
test.ToPrefix() = 1;
70+
test.ToPostfix() = 2;
71+
test.ToPostfixWithNull() = 3;
72+
test.ToMixed() = 5;
73+
74+
Assert.AreEqual(new[] { 1, 0 }, ResultRefStruct.numbersPrefix);
75+
Assert.AreEqual(new[] { 2, 0 }, ResultRefStruct.numbersPostfix);
76+
Assert.AreEqual(new[] { 3 }, ResultRefStruct.numbersPostfixWithNull);
77+
Assert.Throws<Exception>(() => test.ToFinalizer(), "ToFinalizer method does not throw");
78+
Assert.AreEqual(new[] { 5, 0 }, ResultRefStruct.numbersMixed);
79+
80+
var replacements = processor.Patch();
81+
Assert.NotNull(replacements, "replacements");
82+
83+
test.ToPrefix() = -1;
84+
test.ToPostfix() = -2;
85+
test.ToPostfixWithNull() = -3;
86+
test.ToFinalizer() = -4;
87+
test.ToMixed() = -5;
88+
89+
Assert.AreEqual(new[] { 1, -1 }, ResultRefStruct.numbersPrefix);
90+
Assert.AreEqual(new[] { 2, -2 }, ResultRefStruct.numbersPostfix);
91+
Assert.AreEqual(new[] { -3 }, ResultRefStruct.numbersPostfixWithNull);
92+
Assert.AreEqual(new[] { -4 }, ResultRefStruct.numbersFinalizer);
93+
Assert.AreEqual(new[] { 42, -5 }, ResultRefStruct.numbersMixed);
94+
}
95+
5396
[Test]
5497
public void Test_Patch_ConcreteClass()
5598
{
@@ -327,7 +370,7 @@ public void Test_PatchExternalMethod()
327370
Assert.NotNull(patcher, "Patch processor");
328371
_ = patcher.Patch();
329372
}
330-
373+
331374
[Test]
332375
public void Test_PatchEventHandler()
333376
{
@@ -348,7 +391,7 @@ public void Test_PatchEventHandler()
348391
new EventHandlerTestClass().Run();
349392
Console.WriteLine($"### EventHandlerTestClass AFTER");
350393
}
351-
394+
352395
[Test]
353396
public void Test_PatchMarshalledClass()
354397
{
@@ -369,7 +412,7 @@ public void Test_PatchMarshalledClass()
369412
new MarshalledTestClass().Run();
370413
Console.WriteLine($"### MarshalledTestClass AFTER");
371414
}
372-
415+
373416
[Test]
374417
public void Test_MarshalledWithEventHandler1()
375418
{
@@ -390,7 +433,7 @@ public void Test_MarshalledWithEventHandler1()
390433
new MarshalledWithEventHandlerTest1Class().Run();
391434
Console.WriteLine($"### MarshalledWithEventHandlerTest1 AFTER");
392435
}
393-
436+
394437
[Test]
395438
public void Test_MarshalledWithEventHandler2()
396439
{

0 commit comments

Comments
 (0)