123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- package monkey
- import (
- "fmt"
- "reflect"
- "sync"
- "unsafe"
- )
// patch is the bookkeeping kept for one applied monkey patch; it holds
// everything required to undo that patch later.
type patch struct {
	// originalBytes are the bytes replaceFunction reported for the target;
	// unpatch writes them back to the target's code address.
	originalBytes []byte
	// replacement keeps a Go-visible reference to the replacement func value,
	// whose address is otherwise handed to replaceFunction only as a uintptr.
	replacement *reflect.Value
}
- var (
- lock = sync.Mutex{}
- patches = make(map[reflect.Value]patch)
- )
// value is an overlay for the start of a reflect.Value: one word that is
// skipped, followed by the data pointer. The field order must not change.
//
// NOTE(review): this assumes reflect.Value's unexported layout begins with a
// type word and then the data pointer; that layout is not a Go API guarantee,
// so re-check it when upgrading the toolchain.
type value struct {
	_   uintptr
	ptr unsafe.Pointer
}

// getPtr extracts the data pointer held inside v by reinterpreting v's
// memory as a value. For a func Value this is the pointer to the func value
// itself; dereferencing it as a uintptr yields the function's code address.
func getPtr(v reflect.Value) unsafe.Pointer {
	header := (*value)(unsafe.Pointer(&v))
	return header.ptr
}
// PatchGuard is returned by Patch and PatchInstanceMethod. It remembers the
// patched function and its replacement so the patch can be removed with
// Unpatch and applied again with Restore.
type PatchGuard struct {
	target      reflect.Value // function whose code gets patched
	replacement reflect.Value // function that runs in target's place
}
- func (g *PatchGuard) Unpatch() {
- unpatchValue(g.target)
- }
- func (g *PatchGuard) Restore() {
- patchValue(g.target, g.replacement)
- }
- // Patch replaces a function with another
- func Patch(target, replacement interface{}) *PatchGuard {
- t := reflect.ValueOf(target)
- r := reflect.ValueOf(replacement)
- patchValue(t, r)
- return &PatchGuard{t, r}
- }
- // PatchInstanceMethod replaces an instance method methodName for the type target with replacement
- // Replacement should expect the receiver (of type target) as the first argument
- func PatchInstanceMethod(target reflect.Type, methodName string, replacement interface{}) *PatchGuard {
- m, ok := target.MethodByName(methodName)
- if !ok {
- panic(fmt.Sprintf("unknown method %s", methodName))
- }
- r := reflect.ValueOf(replacement)
- patchValue(m.Func, r)
- return &PatchGuard{m.Func, r}
- }
- func patchValue(target, replacement reflect.Value) {
- lock.Lock()
- defer lock.Unlock()
- if target.Kind() != reflect.Func {
- panic("target has to be a Func")
- }
- if replacement.Kind() != reflect.Func {
- panic("replacement has to be a Func")
- }
- if target.Type() != replacement.Type() {
- panic(fmt.Sprintf("target and replacement have to have the same type %s != %s", target.Type(), replacement.Type()))
- }
- if patch, ok := patches[target]; ok {
- unpatch(target, patch)
- }
- bytes := replaceFunction(*(*uintptr)(getPtr(target)), uintptr(getPtr(replacement)))
- patches[target] = patch{bytes, &replacement}
- }
- // Unpatch removes any monkey patches on target
- // returns whether target was patched in the first place
- func Unpatch(target interface{}) bool {
- return unpatchValue(reflect.ValueOf(target))
- }
- // UnpatchInstanceMethod removes the patch on methodName of the target
- // returns whether it was patched in the first place
- func UnpatchInstanceMethod(target reflect.Type, methodName string) bool {
- m, ok := target.MethodByName(methodName)
- if !ok {
- panic(fmt.Sprintf("unknown method %s", methodName))
- }
- return unpatchValue(m.Func)
- }
- // UnpatchAll removes all applied monkeypatches
- func UnpatchAll() {
- lock.Lock()
- defer lock.Unlock()
- for target, p := range patches {
- unpatch(target, p)
- delete(patches, target)
- }
- }
- // Unpatch removes a monkeypatch from the specified function
- // returns whether the function was patched in the first place
- func unpatchValue(target reflect.Value) bool {
- lock.Lock()
- defer lock.Unlock()
- patch, ok := patches[target]
- if !ok {
- return false
- }
- unpatch(target, patch)
- delete(patches, target)
- return true
- }
- func unpatch(target reflect.Value, p patch) {
- copyToLocation(*(*uintptr)(getPtr(target)), p.originalBytes)
- }
|