柯里化 Lua

lua-users home
wiki

柯里化 被维基百科[1] 定义如下

"在计算机科学中,柯里化是一种将接受多个参数的函数转换为接受单个参数(原始函数的第一个参数)并返回一个新函数的技术,该新函数接受剩余参数并返回结果"

您可以在所有支持函数作为一等公民的语言中实现柯里化函数。例如,有一个关于柯里化 JavaScript 的小[教程]

这是一个柯里化函数的小型 Lua 示例

function sum(number) 
  return function(anothernumber) 
    return number + anothernumber
  end
end

local f = sum(5)
print(f(3)) --> 8

-- WalterCruz

这是另一个由 [GavinWraith] 贡献的示例,它接受以 "()" 结尾的可变数量的参数

function addup(x)
  local sum = 0
  local function f(n)
    if type(n) == "number" then
      sum = sum + n
      return f
    else
      return sum
    end
  end
  return f(x)
end

print(addup (1) (2) (3) ())  --> 6
print(addup (4) (5) (6) ())  --> 15

虽然这些预柯里化函数很有用,但我们真正想做的是创建一个通用函数,它可以对任何其他函数执行柯里化操作。为此,我们需要意识到函数可以由 "高阶函数" 操作 - 一个以函数作为参数的函数。以下柯里化函数就是一个例子,它柯里化了一个 2 参数函数

function curry(f)
    return function (x) return function (y) return f(x,y) end end
end

powcurry = curry(math.pow)
powcurry (2) (4) --> 16
pow2 = powcurry(2)
pow2(3) --> 8
pow2(4) --> 16
pow2(8) --> 256

从柯里化 2 个参数到柯里化 'n' 个参数要复杂一些。我们需要存储不确定的部分应用数量,不幸的是,Lua 无法知道函数需要多少个参数;Lua 函数可以成功接收任意数量的参数,无论过多还是过少。因此,有必要告诉柯里化函数在将收集到的参数应用于原始函数之前,接受多少个单参数调用。

(此代码可从 http://tinylittlelife.org/?p=249 免费获取,并包含关于如何解决此问题的完整讨论。)

-- curry(func, num_args) : take a function requiring a tuple for num_args arguments
--                         and turn it into a series of 1-argument functions
-- e.g.: you have a function dosomething(a, b, c)
-- curried_dosomething = curry(dosomething, 3) -- we want to curry 3 arguments
-- curried_dosomething (a1) (b1) (c1) -- returns the result of dosomething(a1, b1, c1)
-- partial_dosomething1 = curried_dosomething (a_value) -- returns a function
-- partial_dosomething2 = partial_dosomething1 (b_value) -- returns a function
-- partial_dosomething2 (c_value) -- returns the result of dosomething(a_value, b_value, c_value)
function curry(func, num_args)

   -- currying 2-argument functions seems to be the most popular application
   num_args = num_args or 2

   -- no sense currying for 1 arg or less
   if num_args <= 1 then return func end

   -- helper takes an argtrace function, and number of arguments remaining to be applied
   local function curry_h(argtrace, n)
      if 0 == n then
	 -- kick off argtrace, reverse argument list, and call the original function
         return func(reverse(argtrace()))
      else
         -- "push" argument (by building a wrapper function) and decrement n
         return function (onearg)
                   return curry_h(function () return onearg, argtrace() end, n - 1)
                end
      end
   end  
   
   -- push the terminal case of argtrace into the function first
   return curry_h(function () return end, num_args)

end

-- reverse(...) : take some tuple and return a tuple of elements in reverse order
--                  
-- e.g. "reverse(1,2,3)" returns 3,2,1
function reverse(...)

   --reverse args by building a function to do it, similar to the unpack() example
   local function reverse_h(acc, v, ...)
      if 0 == select('#', ...) then
	 return v, acc()
      else
         return reverse_h(function () return v, acc() end, ...)
      end
   end  

   -- initial acc is the end of the list
   return reverse_h(function () return end, ...)
end

以上代码与 Lua 5.1 兼容。

由于 Lua 5.2(或 LuaJIT 2.0)提供了一个高级的 debug.getinfo 函数,它可以让我们知道函数需要多少个参数,我们可以创建一个将柯里化和部分应用技术结合在一起的实用函数。以下是代码

function curry(func, num_args)
  num_args = num_args or debug.getinfo(func, "u").nparams
  if num_args < 2 then return func end
  local function helper(argtrace, n)
    if n < 1 then
      return func(unpack(flatten(argtrace)))
    else
      return function (...)
        return helper({argtrace, ...}, n - select("#", ...))
      end
    end
  end
  return helper({}, num_args)
end

function flatten(t)
  local ret = {}
  for _, v in ipairs(t) do
    if type(v) == 'table' then
      for _, fv in ipairs(flatten(v)) do
        ret[#ret + 1] = fv
      end
    else
      ret[#ret + 1] = v
    end
  end
  return ret
end

function multiplyAndAdd (a, b, c) return a * b + c end

curried_multiplyAndAdd = curry(multiplyAndAdd)

multiplyBySevenAndAdd = curried_multiplyAndAdd(7)

multiplySevenByEightAndAdd_v1 = multiplyBySevenAndAdd(8)
multiplySevenByEightAndAdd_v2 = curried_multiplyAndAdd(7, 8)

assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v1(9))
assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v2(9))
assert(multiplyAndAdd(7, 8, 9) == multiplyBySevenAndAdd(8, 9))
assert(multiplyAndAdd(7, 8, 9) == curried_multiplyAndAdd(7, 8, 9))

另请参阅


最近更改 · 偏好设置
编辑 · 历史记录
最后编辑于 2014 年 3 月 27 日下午 2:39 GMT (差异)