package memoize import ( "fmt" "reflect" ) // fptr is a pointer to a function variable which will receive a // memoized wrapper around function impl. Impl must have 1 or more // arguments, all of which must be usable as map keys; and it must // have 1 or more return values. func Memoize(fptr, impl interface{}) { implType := reflect.TypeOf(impl) implValue := reflect.ValueOf(impl) if implType.Kind() != reflect.Func { panic(fmt.Sprintf("Not a function: %v", impl)) } if implType.NumIn() == 0 { panic(fmt.Sprintf("%v takes no inputs", impl)) } if implType.NumOut() == 0 { panic(fmt.Sprintf("%v gives no outputs", impl)) } if !reflect.PtrTo(implType).AssignableTo(reflect.TypeOf(fptr)) { panic(fmt.Sprintf("Can't assign %v to %v", impl, fptr)) } var resultTypes []reflect.Type for on := 0; on < implType.NumOut(); on++ { out := implType.Out(on) resultTypes = append(resultTypes, out) } mapTypes := make([]reflect.Type, implType.NumIn()) mapType := reflect.TypeOf([]reflect.Value{}) mapTypes[len(mapTypes)-1] = mapType for in := implType.NumIn() - 1; in >= 0; in-- { inType := implType.In(in) mapType = reflect.MapOf(inType, mapType) mapTypes[in] = mapType } m := reflect.MakeMap(mapTypes[0]) mem := func(args []reflect.Value) []reflect.Value { thisMap := m for an := 0; an < len(args)-1; an++ { v := thisMap.MapIndex(args[an]) if !v.IsValid() { v = reflect.MakeMap(mapTypes[an+1]) thisMap.SetMapIndex(args[an], v) } thisMap = v } an := len(args) - 1 v := thisMap.MapIndex(args[an]) var vs []reflect.Value if v.IsValid() { for i := 0; i < v.Len(); i++ { // v.Index() gives us a Value for // Value for int. We need a Value for // int. valval := v.Index(i) val := deVal(valval).(reflect.Value) vs = append(vs, val) } } else { vs = implValue.Call(args) thisMap.SetMapIndex(args[an], reflect.ValueOf(vs)) } return vs } typedMem := reflect.MakeFunc(implType, mem) reflect.ValueOf(fptr).Elem().Set(typedMem) } func deVal(val reflect.Value) interface{} { var result interface{} inner := func(v interface{}) { result = v } reflect.ValueOf(inner).Call([]reflect.Value{val}) return result }