函数式元组

lua-users home
wiki

本文描述了一种用纯函数来表示元组的新颖设计模式。

元组是对象的不可变序列。它们存在于许多编程语言中,包括 [Python][Erlang],以及几乎所有的 [函数式语言]。Lua 字符串是一种特殊的元组,其元素仅限于单个字符。

由于元组是不可变的,它们可以共享而无需复制。另一方面,它们也不能被修改;对元组的任何修改都必须创建一个新的元组。

为了说明这个概念,我们将三维空间中的一个点实现为一个元组 <x, y, z>

Lua 提供了多种实现元组的方法;以下是改编自优秀教科书 [Structure and Interpretation of Computer Programs] 中实现的一种方法。

遵循 Abelson 和 Sussman 的思想,我们将元组表示为一个单参数的函数;该参数本身必须是一个函数;我们可以将其视为方法或槽访问器。

首先,我们需要一个构造函数和一些成员选择器

function Point(_x, _y, _z)
  return function(fn) return fn(_x, _y, _z) end
end

function x(_x, _y, _z) return _x end
function y(_x, _y, _z) return _y end
function z(_x, _y, _z) return _z end

这里发生了什么?Point 接收三个参数,即点的坐标,并返回一个函数;就这些而言,我们将返回值视为不透明的。用一个函数调用 Point 会将该函数作为参数传递给构成元组的对象;选择器只返回其中一个并忽略其他。

> p1 = Point(1, 2, 3)
> =p1(x)
1
> =p1(z)
3

然而,我们不限于选择器;我们可以编写任何任意函数

function vlength(_x, _y, _z)
  return math.sqrt(_x * _x + _y * _y + _z * _z)
end

> =p1(vlength)
3.7416573867739

现在,虽然我们不能修改元组,但我们可以编写函数来创建一个具有特定修改的新元组(这类似于标准 Lua 库中的 string.gsub

function subst_x(_x)
  return function(_, _y, _z) return Point(_x, _y, _z) end
end
function subst_y(_y)
  return function(_x, _, _z) return Point(_x, _y, _z) end
end
function subst_z(_z)
  return function(_x, _y, _) return Point(_x, _y, _z) end
end

gsub 一样,这些函数不会影响原始点的内容

> p2 = p1(subst_x(42))
> =p1(x)
1
> =p2(x)
42

值得注意的是,我们可以使用任何接受任意数量参数的函数

> p2(print)
42      2       3

同样,我们可以编写组合两个点的函数

function vadd(v2)
  return function(_x, _y, _z)
    return Point(_x + v2(x), _y + v2(y), _z + v2(z))
  end
end

function vsubtract(v2)
  return function(_x, _y, _z)
    return Point(_x - v2(x), _y - v2(y), _z - v2(z))
  end
end

> =p1(vadd(p1))(print)
2       4       6

仔细检查 vaddvsubtract(以及各种 substitute 函数)会发现它们实际上是在创建一个带有闭包(它们的原始参数)的临时函数。但是,这些函数不必是临时的。事实上,我们可能希望多次使用特定的转换,在这种情况下,我们可以将其保存下来。

> shiftDiagonally = vadd(Point(1, 1, 1))
> p2(print)
42      2       3
> p2(shiftDiagonally)(print)
43      3       4
> p2(shiftDiagonally)(shiftDiagonally)(print)
44      4       5

这可能会促使我们重新审视 vadd 的定义,以避免创建然后解构参数

function subtractPoint(x, y, z)
  return function(_x, _y, _z) return _x - x, _y - y, _z - z end
end

function addPoint(x, y, z)
  return function(_x, _y, _z) return _x + x, _y + y, _z + z end
end

既然如此,让我们再添加几个转换

function scaleBy(q)
  return function(_x, _y, _z) return q * _x, q * _y, q * _z end
end

function rotateBy(theta)
  local sintheta, costheta = math.sin(theta), math.cos(theta)
  return function(_x, _y, _z)
    return _x * costheta - _y * sintheta, _x * sintheta + _y * costheta, _z
  end
end

请注意,在 rotateBy 中,我们预先计算了正弦和余弦,以避免每次应用函数时都要调用数学库。

现在这些函数不返回 Point;它们只返回构成 Point 的值。要使用它们,我们必须显式创建点

> p3 = Point(p1(scaleBy(10)))
> p3(print)
10      20      30

这有点繁琐。但是正如我们将看到的,它有它的优点。

但首先,让我们再次看看 addPoint。如果我们心中有一个转换,那很好,但如果我们想通过一个特定的点来移动呢?p1(addPoint(p2)) 显然行不通。然而,答案非常简单

> centre = Point(0.5, 0.5, 0.5)
> -- This doesn't work
> =p1(subtractPoint(centre))
stdin:2: attempt to perform arithmetic on a function value
stack traceback:
        stdin:2: in function <stdin:2>
        (tail call): ?
        (tail call): ?
        [C]: ?
> -- But this works just fine:
> =p1(centre(subtractPoint))
0.5     1.5     2.5

此外,这些新函数可以被组合;我们可以有效地将一系列转换创建一个单一的原始操作。

-- A complex transformation
function transform(centre, expand, theta)
  local shift = centre(subtractPoint)
  local exp = scaleBy(expand)
  local rot = rotateBy(theta)
  local unshift = centre(addPoint)
  return function(_x, _y, _z)
    return unshift(exp(rot(shift(_x, _y, _z))))
  end
end

> xform = transform(centre, 10, math.pi / 4)
> =p1(xform)
-6.5710678118655        14.642135623731 25.5

这带来的一个巨大好处是,一旦创建了 xform,它就可以执行而无需创建任何堆对象。所有内存消耗都在栈上。当然,这有点不诚实——创建元组(函数闭包和三个上值)以及创建单个转换器会进行大量的内存分配。

此外,我们还没有处理一些重要的语法问题,例如如何让普通程序员实际使用这些元组。

--RiciLake

泛化到任意大小 N

为了使上述方案适用于任意大小的元组,我们可以使用 代码生成,如下所示——DavidManura

function all(n, ...) return ... end     -- return all elements in tuple
function size(n) return n end           -- return size of tuple
function first(n,e, ...) return e end     -- return first element in tuple
function second(n,_,e, ...) return e end  -- return second element in tuple
function third(n,_,_,e, ...) return e end -- return third element in tuple
local nthf = {first, second, third}
function nth(n)
  return nthf[n] or function(...) return select(n+1, ...) end
end

local function make_tuple_equals(n)
  local ta, tb, te = {}, {}, {}
  for i=1,n do
    ta[#ta+1] = "a" .. i
    tb[#tb+1] = "b" .. i
    te[#te+1] = "a" .. i .. "==b" .. i
  end
  local alist = table.concat(ta, ",")
  if alist ~= "" then alist = "," .. alist end
  local blist = table.concat(tb, ",")
  if blist ~= "" then blist = "," .. blist end
  local elist = table.concat(te, " and ")
  if elist ~= "" then elist = "and " .. elist end
  local s = [[
    local t, n1 %s = ...
    local f = function(n2 %s)
      return n1==n2 %s
    end
    return t(f)
  ]]
  s = string.format(s, alist, blist, elist)
  return assert(loadstring(s))
end

local cache = {}
function equals(t)
  local n = t(size)
  local f = cache[n]; if not f then
    f = make_tuple_equals(n)
    cache[n] = f
  end
  return function(...) return f(t, ...) end
end

local function equals2(t1, t2)
  return t1(equals(t2))
end

local ops = {
  ['#'] = size,
  ['*'] = all,
}
local ops2 = {
  ["number"]   = function(x) return nth(x) end,
  ["function"] = function(x) return x end,
  ["string"]   = function(x) return ops[x] end
}

local function make_tuple_constructor(n)
  local ts = {}
  for i=1,n do ts[#ts+1] = "a" .. i end
  local slist = table.concat(ts, ",")
  local c = slist ~= "" and "," or ""
  local s =
    "local ops2 = ... " ..
    "return function(" .. slist .. ") " ..
    "  return function(f) " ..
     "    return (ops2[type(f)](f))(" ..
                 n .. c .. slist .. ") end " ..
    "end"
  return assert(loadstring(s))(ops2)
end

local cache = {}
function tuple(...)
  local n = select('#', ...)
  local f = cache[n]; if not f then
    f = make_tuple_constructor(n)
    cache[n] = f
  end
  return f(...)
end

测试

-- test suite
local t = tuple(1,nil,2,nil)
;(function(a,b,c,d) assert(a==1 and b==nil and c==2 and d==nil) end)(t(all))
;(function(a,b,c,d) assert(a==1 and b==nil and c==2 and d==nil) end)(t '*')
assert(t(size) == 4)
assert(t '#' == 4)
assert(t(nth(1)) == 1 and t(nth(2)) == nil and t(nth(3)) == 2 and
       t(nth(4)) == nil)
assert(t(1) == 1 and t(2) == nil and t(3) == 2 and t(4) == nil)
assert(t(first) == 1 and t(second) == nil and t(third) == 2)
local t = tuple(3,4,5,6)
assert(t(nth(1)) == 3 and t(nth(2)) == 4 and t(nth(3)) == 5 and
       t(nth(4)) == 6)
assert(t(first) == 3 and t(second) == 4 and t(third) == 5)
assert(tuple()(size) == 0 and tuple(3)(size) == 1 and tuple(3,4)(size) == 2)
assert(tuple(nil)(size) == 1)
assert(tuple(3,nil,5)(equals(tuple(3,nil,5))))
assert(not tuple(3,nil,5)(equals(tuple(3,1,5))))
assert(not tuple(3,nil)(equals(tuple(3,nil,5))))
assert(not tuple(3,5,nil)(equals(tuple(3,5))))
assert(tuple()(equals(tuple())))
assert(tuple(nil)(equals(tuple(nil))))
assert(tuple(1)(equals(tuple(1))))
assert(not tuple(1)(equals(tuple())))
assert(not tuple()(equals(tuple(1))))
assert(equals2(tuple(3,nil,5), tuple(3,nil,5)))
assert(not equals2(tuple(3,nil,5), tuple(3,1,5)))


-- example
function trace(f)
  return function(...)
    print("+function")
    local t = tuple(f(...))
    print("-function")
    return t(all)
  end
end
local test = trace(function (a,b,c)
  print("test",a+b+c)
end)
test(2,3,4)
--[[OUTPUT:
+function
test    9
-function
]]

评论

我认为这个页面有误导性。这些不是元组。元组只有在按值比较时才有意义,这样它们才能被索引(即,可以作为表键)。没有这个属性,它们并不比表好。请参阅 [1],了解使用内部索引树的 n 元组实现。--CosminApreutesei

另请参阅


RecentChanges · preferences
编辑 · 历史
最后编辑于 2014 年 9 月 12 日 上午 10:09 GMT (差异)