函数式元组

lua-users home
wiki

本文介绍了一种新颖的设计模式,用于仅使用函数来表达元组。

元组 是不可变的对象序列。它们存在于许多编程语言中,包括 [Python][Erlang],以及大多数 [函数式语言]。Lua 字符串是一种特殊的元组类型,其元素仅限于单个字符。

由于元组是不可变的,因此它们可以在不复制的情况下共享。另一方面,它们不能被修改;对元组的任何修改都必须创建一个新的元组。

为了解释这个概念,我们将三维空间中的一个点实现为一个元组 <x, y, z>

Lua 提供了许多实现元组的方法;以下是对优秀教科书 [计算机程序的结构和解释] 中找到的实现的改编。

遵循 Abelson 和 Sussman 的方法,我们将元组表示为一个函数,它接受一个参数;该参数本身必须是一个函数;我们可以将其视为方法或槽访问器。

首先,我们需要一个构造函数和一些成员选择器

function Point(_x, _y, _z)
  return function(fn) return fn(_x, _y, _z) end
end

function x(_x, _y, _z) return _x end
function y(_x, _y, _z) return _y end
function z(_x, _y, _z) return _z end

这里发生了什么?Point 接受三个参数,即点的坐标,并返回一个函数;出于这些目的,我们将返回值视为不透明的。用一个函数调用 Point 会将该函数作为参数传递给构成元组的对象;选择器只是返回其中一个并忽略其他。

> p1 = Point(1, 2, 3)
> =p1(x)
1
> =p1(z)
3

但是,我们并不局限于选择器;我们可以编写任何任意函数

function vlength(_x, _y, _z)
  return math.sqrt(_x * _x + _y * _y + _z * _z)
end

> =p1(vlength)
3.7416573867739

现在,虽然我们不能修改元组,但我们可以编写创建具有特定修改的新元组的函数(这类似于标准 Lua 库中的 string.gsub

function subst_x(_x)
  return function(_, _y, _z) return Point(_x, _y, _z) end
end
function subst_y(_y)
  return function(_x, _, _z) return Point(_x, _y, _z) end
end
function subst_z(_z)
  return function(_x, _y, _) return Point(_x, _y, _z) end
end

gsub 一样,这些函数不会影响原始点的内容

> p2 = p1(subst_x(42))
> =p1(x)
1
> =p2(x)
42

有趣的是,我们可以使用任何接受任意参数序列的函数

> p2(print)
42      2       3

类似地,我们可以编写组合两个点的函数

function vadd(v2)
  return function(_x, _y, _z)
    return Point(_x + v2(x), _y + v2(y), _z + v2(z))
  end
end

function vsubtract(v2)
  return function(_x, _y, _z)
    return Point(_x - v2(x), _y - v2(y), _z - v2(z))
  end
end

> =p1(vadd(p1))(print)
2       4       6

仔细检查 vaddvsubtract(以及各种替换函数)表明它们实际上是在创建一个带有闭包上值的临时函数(它们原来的参数)。但是,这些函数没有理由是临时的。实际上,我们可能希望多次使用特定的转换,在这种情况下,我们可以直接保存它

> shiftDiagonally = vadd(Point(1, 1, 1))
> p2(print)
42      2       3
> p2(shiftDiagonally)(print)
43      3       4
> p2(shiftDiagonally)(shiftDiagonally)(print)
44      4       5

这可能促使我们重新审视 vadd 的定义,以避免创建和解构参数

function subtractPoint(x, y, z)
  return function(_x, _y, _z) return _x - x, _y - y, _z - z end
end

function addPoint(x, y, z)
  return function(_x, _y, _z) return _x + x, _y + y, _z + z end
end

趁此机会,让我们添加几个其他转换

function scaleBy(q)
  return function(_x, _y, _z) return q * _x, q * _y, q * _z end
end

function rotateBy(theta)
  local sintheta, costheta = math.sin(theta), math.cos(theta)
  return function(_x, _y, _z)
    return _x * costheta - _y * sintheta, _x * sintheta + _y * costheta, _z
  end
end

注意,在 rotateBy 中,我们预先计算了正弦和余弦,以便在每次应用函数时不必调用数学库。

现在这些函数不返回一个Point;它们只是返回构成Point的值。要使用它们,我们必须显式地创建点。

> p3 = Point(p1(scaleBy(10)))
> p3(print)
10      20      30

这有点繁琐。但正如我们将会看到的,它有它的优势。

但首先,让我们再次看看addPoint。如果我们有一个变换,这很好,但如果我们想按一个特定的点移动呢?p1(addPoint(p2))显然行不通。然而,答案很简单。

> centre = Point(0.5, 0.5, 0.5)
> -- This doesn't work
> =p1(subtractPoint(centre))
stdin:2: attempt to perform arithmetic on a function value
stack traceback:
        stdin:2: in function <stdin:2>
        (tail call): ?
        (tail call): ?
        [C]: ?
> -- But this works just fine:
> =p1(centre(subtractPoint))
0.5     1.5     2.5

此外,这些新函数可以组合;实际上,我们可以将一系列变换作为一个单一的基元创建。

-- A complex transformation
function transform(centre, expand, theta)
  local shift = centre(subtractPoint)
  local exp = scaleBy(expand)
  local rot = rotateBy(theta)
  local unshift = centre(addPoint)
  return function(_x, _y, _z)
    return unshift(exp(rot(shift(_x, _y, _z))))
  end
end

> xform = transform(centre, 10, math.pi / 4)
> =p1(xform)
-6.5710678118655        14.642135623731 25.5

这带来的一个巨大好处是,一旦创建了xform,它就可以在不创建任何堆对象的情况下执行。所有内存消耗都在堆栈上。当然,这有点不诚实——创建元组(一个函数闭包和三个上值)以及创建单个变换器需要相当多的存储分配。

此外,我们还没有设法处理一些重要的语法问题,比如如何让这些元组真正被普通程序员使用。

--RiciLake

推广到任意大小 N

为了使上述方案适用于任意大小的元组,我们可以使用代码生成,如下所示——DavidManura.

function all(n, ...) return ... end     -- return all elements in tuple
function size(n) return n end           -- return size of tuple
function first(n,e, ...) return e end     -- return first element in tuple
function second(n,_,e, ...) return e end  -- return second element in tuple
function third(n,_,_,e, ...) return e end -- return third element in tuple
local nthf = {first, second, third}
function nth(n)
  return nthf[n] or function(...) return select(n+1, ...) end
end

local function make_tuple_equals(n)
  local ta, tb, te = {}, {}, {}
  for i=1,n do
    ta[#ta+1] = "a" .. i
    tb[#tb+1] = "b" .. i
    te[#te+1] = "a" .. i .. "==b" .. i
  end
  local alist = table.concat(ta, ",")
  if alist ~= "" then alist = "," .. alist end
  local blist = table.concat(tb, ",")
  if blist ~= "" then blist = "," .. blist end
  local elist = table.concat(te, " and ")
  if elist ~= "" then elist = "and " .. elist end
  local s = [[
    local t, n1 %s = ...
    local f = function(n2 %s)
      return n1==n2 %s
    end
    return t(f)
  ]]
  s = string.format(s, alist, blist, elist)
  return assert(loadstring(s))
end

local cache = {}
function equals(t)
  local n = t(size)
  local f = cache[n]; if not f then
    f = make_tuple_equals(n)
    cache[n] = f
  end
  return function(...) return f(t, ...) end
end

local function equals2(t1, t2)
  return t1(equals(t2))
end

local ops = {
  ['#'] = size,
  ['*'] = all,
}
local ops2 = {
  ["number"]   = function(x) return nth(x) end,
  ["function"] = function(x) return x end,
  ["string"]   = function(x) return ops[x] end
}

local function make_tuple_constructor(n)
  local ts = {}
  for i=1,n do ts[#ts+1] = "a" .. i end
  local slist = table.concat(ts, ",")
  local c = slist ~= "" and "," or ""
  local s =
    "local ops2 = ... " ..
    "return function(" .. slist .. ") " ..
    "  return function(f) " ..
     "    return (ops2[type(f)](f))(" ..
                 n .. c .. slist .. ") end " ..
    "end"
  return assert(loadstring(s))(ops2)
end

local cache = {}
function tuple(...)
  local n = select('#', ...)
  local f = cache[n]; if not f then
    f = make_tuple_constructor(n)
    cache[n] = f
  end
  return f(...)
end

测试

-- test suite
local t = tuple(1,nil,2,nil)
;(function(a,b,c,d) assert(a==1 and b==nil and c==2 and d==nil) end)(t(all))
;(function(a,b,c,d) assert(a==1 and b==nil and c==2 and d==nil) end)(t '*')
assert(t(size) == 4)
assert(t '#' == 4)
assert(t(nth(1)) == 1 and t(nth(2)) == nil and t(nth(3)) == 2 and
       t(nth(4)) == nil)
assert(t(1) == 1 and t(2) == nil and t(3) == 2 and t(4) == nil)
assert(t(first) == 1 and t(second) == nil and t(third) == 2)
local t = tuple(3,4,5,6)
assert(t(nth(1)) == 3 and t(nth(2)) == 4 and t(nth(3)) == 5 and
       t(nth(4)) == 6)
assert(t(first) == 3 and t(second) == 4 and t(third) == 5)
assert(tuple()(size) == 0 and tuple(3)(size) == 1 and tuple(3,4)(size) == 2)
assert(tuple(nil)(size) == 1)
assert(tuple(3,nil,5)(equals(tuple(3,nil,5))))
assert(not tuple(3,nil,5)(equals(tuple(3,1,5))))
assert(not tuple(3,nil)(equals(tuple(3,nil,5))))
assert(not tuple(3,5,nil)(equals(tuple(3,5))))
assert(tuple()(equals(tuple())))
assert(tuple(nil)(equals(tuple(nil))))
assert(tuple(1)(equals(tuple(1))))
assert(not tuple(1)(equals(tuple())))
assert(not tuple()(equals(tuple(1))))
assert(equals2(tuple(3,nil,5), tuple(3,nil,5)))
assert(not equals2(tuple(3,nil,5), tuple(3,1,5)))


-- example
function trace(f)
  return function(...)
    print("+function")
    local t = tuple(f(...))
    print("-function")
    return t(all)
  end
end
local test = trace(function (a,b,c)
  print("test",a+b+c)
end)
test(2,3,4)
--[[OUTPUT:
+function
test    9
-function
]]

评论

我认为此页面具有误导性。这些不是元组。元组只有在按值比较时才有用,这样它们就可以被索引(即可以用作表键)。如果没有这个属性,它们就比表好不了多少。有关使用内部索引树的 n 元组实现,请参见[1]。--CosminApreutesei.

另请参见


最近更改 · 偏好设置
编辑 · 历史
最后编辑于 2014 年 9 月 12 日下午 4:09 GMT (差异)