monkey.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package monkey
  2. import (
  3. "fmt"
  4. "reflect"
  5. "sync"
  6. "unsafe"
  7. )
  8. // patch is an applied patch
  9. // needed to undo a patch
  10. type patch struct {
  11. originalBytes []byte
  12. replacement *reflect.Value
  13. }
  14. var (
  15. lock = sync.Mutex{}
  16. patches = make(map[reflect.Value]patch)
  17. )
  18. type value struct {
  19. _ uintptr
  20. ptr unsafe.Pointer
  21. }
  22. func getPtr(v reflect.Value) unsafe.Pointer {
  23. return (*value)(unsafe.Pointer(&v)).ptr
  24. }
  25. type PatchGuard struct {
  26. target reflect.Value
  27. replacement reflect.Value
  28. }
  29. func (g *PatchGuard) Unpatch() {
  30. unpatchValue(g.target)
  31. }
  32. func (g *PatchGuard) Restore() {
  33. patchValue(g.target, g.replacement)
  34. }
  35. // Patch replaces a function with another
  36. func Patch(target, replacement interface{}) *PatchGuard {
  37. t := reflect.ValueOf(target)
  38. r := reflect.ValueOf(replacement)
  39. patchValue(t, r)
  40. return &PatchGuard{t, r}
  41. }
  42. // PatchInstanceMethod replaces an instance method methodName for the type target with replacement
  43. // Replacement should expect the receiver (of type target) as the first argument
  44. func PatchInstanceMethod(target reflect.Type, methodName string, replacement interface{}) *PatchGuard {
  45. m, ok := target.MethodByName(methodName)
  46. if !ok {
  47. panic(fmt.Sprintf("unknown method %s", methodName))
  48. }
  49. r := reflect.ValueOf(replacement)
  50. patchValue(m.Func, r)
  51. return &PatchGuard{m.Func, r}
  52. }
  53. func patchValue(target, replacement reflect.Value) {
  54. lock.Lock()
  55. defer lock.Unlock()
  56. if target.Kind() != reflect.Func {
  57. panic("target has to be a Func")
  58. }
  59. if replacement.Kind() != reflect.Func {
  60. panic("replacement has to be a Func")
  61. }
  62. if target.Type() != replacement.Type() {
  63. panic(fmt.Sprintf("target and replacement have to have the same type %s != %s", target.Type(), replacement.Type()))
  64. }
  65. if patch, ok := patches[target]; ok {
  66. unpatch(target, patch)
  67. }
  68. bytes := replaceFunction(*(*uintptr)(getPtr(target)), uintptr(getPtr(replacement)))
  69. patches[target] = patch{bytes, &replacement}
  70. }
  71. // Unpatch removes any monkey patches on target
  72. // returns whether target was patched in the first place
  73. func Unpatch(target interface{}) bool {
  74. return unpatchValue(reflect.ValueOf(target))
  75. }
  76. // UnpatchInstanceMethod removes the patch on methodName of the target
  77. // returns whether it was patched in the first place
  78. func UnpatchInstanceMethod(target reflect.Type, methodName string) bool {
  79. m, ok := target.MethodByName(methodName)
  80. if !ok {
  81. panic(fmt.Sprintf("unknown method %s", methodName))
  82. }
  83. return unpatchValue(m.Func)
  84. }
  85. // UnpatchAll removes all applied monkeypatches
  86. func UnpatchAll() {
  87. lock.Lock()
  88. defer lock.Unlock()
  89. for target, p := range patches {
  90. unpatch(target, p)
  91. delete(patches, target)
  92. }
  93. }
  94. // Unpatch removes a monkeypatch from the specified function
  95. // returns whether the function was patched in the first place
  96. func unpatchValue(target reflect.Value) bool {
  97. lock.Lock()
  98. defer lock.Unlock()
  99. patch, ok := patches[target]
  100. if !ok {
  101. return false
  102. }
  103. unpatch(target, patch)
  104. delete(patches, target)
  105. return true
  106. }
  107. func unpatch(target reflect.Value, p patch) {
  108. copyToLocation(*(*uintptr)(getPtr(target)), p.originalBytes)
  109. }