Vararg 的二等公民

lua-users home
wiki

Lua 中的 varargs "..." [1] 在 Lua 5.1 中不是 [一等公民] 对象,这会导致一些表达上的限制。这里列出了一些问题及其解决方法。

问题 #1: 保存 Vararg

Lua 5.1 中的 vararg (...) 处理有一些限制。例如,它不允许以下操作

function tuple(...) return function() return ... end end
--Gives error "cannot use '...' outside a vararg function near '...'"

(关于此的一些评论可以在 LuaList:2007-03/msg00249.html 中找到。)

你可能希望使用这样的函数来临时存储函数调用的返回值,执行一些其他操作,然后再次检索这些存储的返回值。以下函数将使用这个假设的 tuple 函数在给定函数周围添加跟踪语句

--Wraps a function with trace statements.
function trace(f)
  return function(...)
    print("begin", f)
    local result = tuple(f(...))
    print("end", f)
    return result()
  end
end

test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
print("returns:", test(2,3,nil))
-- Desired Output:
--   begin   function: 0x687350
--   calc    2       3       nil
--   end     function: 0x687350
--   returns:        5       nil

尽管如此,在 Lua 中仍然有方法可以实现这一点。

解决方案: {...} unpack

你可以使用表结构 {...} unpack 来实现 trace

--Wraps a function with trace statements.
function trace(f)
  return function(...)
    print("begin", f)
    local result = {f(...)}
    print("end", f)
    return unpack(result)
  end
end

test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
print("returns:", test(2,3,nil))
-- Output:
--   begin   function: 0x6869d0
--   calc    2       3       nil
--   end     function: 0x6869d0
--   returns:        5

不幸的是,它会丢失 nil 返回值,因为 nil 不能显式存储在表中,特别是 {...} 不会保留有关尾随 nil 的信息(这将在 StoringNilsInTables 中进一步讨论)。

解决方案: {...} unpackn

以下是对先前解决方案的改进,它可以正确处理 nil

function pack2(...) return {n=select('#', ...), ...} end
function unpack2(t) return unpack(t, 1, t.n) end

--Wraps a function with trace statements.
function trace(f)
  return function(...)
    print("begin", f)
    local result = pack2(f(...))
    print("end", f)
    return unpack2(result);
  end
end

test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
print("returns:", test(2,3,nil))
-- Output:
--   begin   function: 0x6869d0
--   calc    2       3       nil
--   end     function: 0x6869d0
--   returns:        5       nil

Shirik 指出的一个变体是

local function tuple(...)
  local n = select('#', ...)
  local t = {...}
  return function() return unpack(t, 1, n) end
end

解决方案: nil 占位符

以下方法将 nil 与可以存储在表中的占位符交换。这里可能不太理想,但这种方法可能在其他地方有用。

local NIL = {} -- placeholder value for nil, storable in table.
function pack2(...)
  local n = select('#', ...)
  local t = {}
  for i = 1,n do
    local v = select(i, ...)
    t[i] = (v == nil) and NIL or v
  end
  return t
end

function unpack2(t)  --caution: modifies t
  if #t == 0 then
    return
  else
    local v = table.remove(t, 1)
    if v == NIL then v = nil end
    return v, unpack2(t)
  end
end

--Wraps a function with trace statements.
function trace(f)
  return function(...)
    print("begin", f)
    local result = pack2(f(...))
    print("end", f)
    return unpack2(result)
  end
end

test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
print("returns:", test(2,3,nil))
-- Output:
--   begin   function: 0x687350
--   calc    2       3       nil
--   end     function: 0x687350
--   returns:        5       nil

以下是 pack2 和 unpack2 的更优化的实现

local NIL = {} -- placeholder value for nil, storable in table.
function pack2(...)
  local n = select('#', ...)
  local t = {...}
  for i = 1,n do
    if t[i] == nil then
      t[i] = NIL
    end
  end
  return t
end

function unpack2(t, k, n)
  k = k or 1
  n = n or #t
  if k > n then return end
  local v = t[k]
  if v == NIL then v = nil end
  return v, unpack2(t, k + 1, n)
end
作为一项很好的副作用,现在 unpack2 可以操作索引范围 [k, n] 而不是整个表。如果你没有指定范围,则会解包整个表。--Sergey Rozhenko,2009,Lua 5.1

另请参阅 StoringNilsInTables.

解决方案: 延续传递风格 (CPS)

如果我们使用Continuation passing style (CPS)([维基百科])如下所示,可以避免使用表格。我们可以预期这会更有效率。

function trace(f)
  local helper = function(...)
    print("end", f)
    return ...
  end
  return function(...)
    print("begin", f)
    return helper(f(...))
  end
end

test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
print("returns:", test(2,3,nil))
-- Output:
--   begin   function: 0x686b10
--   calc    2       3       nil
--   end     function: 0x686b10
--   returns:        5       nil

CPS 方法也用在 RiciLake 的字符串分割函数中(LuaList:2006-12/msg00414.html)。

解决方案:代码生成

另一种方法是代码生成,它为每个元组大小编译一个单独的构造函数。在构建构造函数时会有一些初始开销,但构造函数本身可以被最佳地实现。之前使用的 tuple 函数可以这样实现

local function build_constructor(n)
  local t = {}; for i = 1,n do t[i] = "a" .. i end
  local arglist = table.concat(t, ',')
  local src = "return function(" .. arglist ..
              ") return function() return " .. arglist .. " end end"
  return assert(loadstring(src))()
end
function tuple(...)
  local construct = build_constructor(select('#', ...))
  return construct(...)
end

为了避免每次存储时代码生成的开销,我们可以记忆 make_storeimpl 函数(有关背景信息,请参见 [维基百科:记忆化]FuncTables)。

function Memoize(fn)
  return setmetatable({}, {
    __index = function(t, k) local val = fn(k); t[k] = val; return val end,
    __call  = function(t, k) return t[k] end
  })
end

build_constructor = Memoize(build_constructor)

通过代码生成实现元组的更完整示例可以在 FunctionalTuples 中找到。

代码构建/记忆化技术和上面的 Memoize 函数基于 RiciLake 的一些先前相关的示例,例如 [Re: The Curry Challenge]

还要注意,如果您的包装函数引发异常,您可能还需要使用 pcallLuaList:2007-02/msg00165.html)。

解决方案:函数式,递归

以下方法纯粹是函数式的(没有表格)并且避免了代码生成。它不一定是最高效的,因为它为每个元组元素创建一个函数。

function helper(n, first, ...)
  if n == 1 then
    return function() return first end
  else
    local rest = helper(n-1, ...)
    return function() return first, rest() end
  end
end
function tuple(...)
  local n = select('#', ...)
  return (n == 0) and function() end or helper(n, ...)
end


-- TEST
local function join(...)
  local t = {n=select('#', ...), ...}
  for i=1,t.n do t[i] = tostring(t[i]) end
  return table.concat(t, ",")
end
local t = tuple()
assert(join(t()) == "")
t = tuple(2,3,nil,4,nil)
assert(join(t()) == "2,3,nil,4,nil")
print "done"

解决方案:协程

另一个想法是使用协程

do
  local function helper(...)
    coroutine.yield()
    return ...
  end
  function pack2(...)
    local o = coroutine.create(helper)
    coroutine.resume(o, ...)
    return o
  end
  function unpack2(o)
    return select(2, coroutine.resume(o))
  end
end

LuaList:2007-02/msg00142.html 中发布了一个类似的建议。但这可能效率低下(RiciLake 指出,一个最小的协程占用略多于 1k 加上 malloc 开销,在 freebsd 上总计接近 2k,最大的部分是堆栈,默认情况下为 45 个插槽 @ 12 或 16 字节)。

没有必要在每次调用时都创建一个新的协程。以下方法相当高效,递归使用尾调用

local yield = coroutine.yield
local resume = coroutine.resume
local function helper(...)
  yield(); return helper(yield(...))
end
local function make_stack() return coroutine.create(helper) end

-- Example
local stack = make_stack()
local function trace(f)
  return function(...)
    print("begin", f)
    resume(stack, f(...))
    print("end", f)
    return select(2, resume(stack))
  end
end

解决方案:C 闭包中的 Upvalues

元组可以在 C 中实现为一个闭包,该闭包包含元组元素作为 Upvalues。这在 Programming In Lua, 2nd Ed 的第 27.3 节中进行了演示 [2]

基准测试

以下基准测试比较了上述解决方案的速度。

-- Avoid global table accesses in benchmark.
local time = os.time
local unpack = unpack
local select = select

-- Benchmarks function f using chunks of nbase iterations for duration
-- seconds in ntrials trials.
local function bench(duration, nbase, ntrials, func, ...)
  assert(nbase % 10 == 0)
  local nloops = nbase/10
  local ts = {}
  for k=1,ntrials do
    local t1, t2 = time()
    local nloops2 = 0
    repeat
      for j=1,nloops do
        func(...) func(...) func(...) func(...) func(...)
        func(...) func(...) func(...) func(...) func(...)
      end
      t2 = time()
      nloops2 = nloops2 + 1
   until t2 - t1 >= duration
    local t = (t2-t1) / (nbase * nloops2)
    ts[k] = t
  end
  return unpack(ts)
end

local function print_bench(name, duration, nbase, ntrials, func, ...)
  local fmt = (" %0.1e"):rep(ntrials)
  print(string.format("%25s:" .. fmt, name,
                      bench(duration, nbase, ntrials, func, ...) ))
end

-- Test all methods.
local function test_suite(duration, nbase, ntrials)
  print("name" .. (", t"):rep(ntrials) .. " (times in sec)")

  do -- This is a base-line.
    local function trace(f)
      return function(...) return f(...) end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("(control)", duration, nbase, ntrials, f, 1,2,3,4,5)
  end

  do
    local function trace(f)
      local function helper(...)
        return ...
      end
      return function(...)
        return helper(f(...))
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("CPS", duration, nbase, ntrials, f, 1,2,3,4,5)
  end

  do
    local yield = coroutine.yield
    local resume = coroutine.resume
    local function helper(...)
      yield(); return helper(yield(...))
    end
    local function make_stack() return coroutine.create(helper) end
    local stack = make_stack()
    local function trace(f)
      return function(...)
        resume(stack, f(...))
        return select(2, resume(stack))
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("Coroutine", duration, nbase, ntrials, f, 1,2,3,4,5)
  end

  do
    local function trace(f)
      return function(...)
        local t = {f(...)}
        return unpack(t)
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("{...} and unpack", duration, nbase, ntrials, f, 1,2,3,4,5)
  end

  do  
    local function trace(f)
      return function(...)
        local n = select('#', ...)
        local t = {f(...)}
        return unpack(t, 1, n)
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("{...} and unpack with n", duration, nbase, ntrials,
                f, 1,2,3,4,5)
  end

  do
    local NIL = {}
    local function pack2(...)
      local n = select('#', ...)
      local t = {...}
      for i=1,n do
        local v = t[i]
        if t[i] == nil then t[i] = NIL end
      end
      return t
    end
    local function unpack2(t)
      local n = #t
      for i=1,n do
        local v = t[i]
        if t[i] == NIL then t[i] = nil end
      end
      return unpack(t, 1, n)
    end
    local function trace(f)
      return function(...)
        local t = pack2(f(...))
        return unpack2(t)
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("nil Placeholder", duration, nbase, ntrials,
                f, 1,2,3,4,5)
  end

  do
    -- This is a simplified version of Code Generation for comparison.
    local function tuple(a1,a2,a3,a4,a5)
      return function() return a1,a2,a3,a4,a5 end
    end
    local function trace(f)
      return function(...)
        local t = tuple(f(...))
        return t()
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("Closure", duration, nbase, ntrials, f, 1,2,3,4,5)
  end

  do
    local function build_constructor(n)
      local t = {}; for i = 1,n do t[i] = "a" .. i end
      local arglist = table.concat(t, ',')
      local src = "return function(" .. arglist ..
                  ") return function() return " .. arglist .. " end end"
      return assert(loadstring(src))()
    end
    local cache = {}
    local function tuple(...)
      local n = select('#', ...)
      local construct = cache[n]
      if not construct then
        construct = build_constructor(n)
        cache[n] = construct
      end
      return construct(...)
    end
    local function trace(f)
      return function(...)
        local t = tuple(f(...))
        return t()
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("Code Generation", duration, nbase, ntrials,
                f, 1,2,3,4,5)
  end

  do
    local function helper(n, first, ...)
      if n == 1 then
        return function() return first end
      else
        local rest = helper(n-1, ...)
        return function() return first, rest() end
      end
    end
    local function tuple(...)
      local n = select('#', ...)
      return (n == 0) and function() end or helper(n, ...)
    end
    local function trace(f)
      return function(...)
        local t = tuple(f(...))
        return t()
      end
    end
    local f = trace(function() return 11,12,13,14,15 end)
    print_bench("Functional, Recursive", duration, nbase, ntrials,
                f, 1,2,3,4,5)
  end

  -- NOTE: Upvalues in C Closure not benchmarked here.

  print "done"
end

test_suite(10, 1000000, 3)
test_suite(10, 1000000, 1) -- recheck

结果

(Pentium4/3GHz)
name, t, t, t (times in sec)
                (control): 3.8e-007 3.8e-007 4.0e-007
                      CPS: 5.6e-007 6.3e-007 5.9e-007
                Coroutine: 1.7e-006 1.7e-006 1.7e-006
         {...} and unpack: 2.2e-006 2.2e-006 2.4e-006
  {...} and unpack with n: 2.5e-006 2.5e-006 2.5e-006
          nil Placeholder: 5.0e-006 4.7e-006 4.7e-006
                  Closure: 5.0e-006 5.0e-006 5.0e-006
          Code Generation: 5.5e-006 5.5e-006 5.5e-006
    Functional, Recursive: 1.3e-005 1.3e-005 1.3e-005
done

CPS 最快,其次是协程(两者都在堆栈上运行)。表比协程方法花费更多时间,尽管如果我们没有在 `resume` 上使用 `select`,协程可能会更快。闭包的使用速度要慢几倍(包括使用代码生成进行泛化),甚至慢一个数量级(如果使用函数式、递归进行泛化)。

对于元组大小为 1,我们得到

name, t, t, t (times in sec)
                (control): 2.9e-007 2.8e-007 2.7e-007
                      CPS: 4.3e-007 4.3e-007 4.3e-007
                Coroutine: 1.4e-006 1.4e-006 1.4e-006
         {...} and unpack: 2.0e-006 2.2e-006 2.2e-006
  {...} and unpack with n: 2.4e-006 2.5e-006 2.4e-006
          nil Placeholder: 3.3e-006 3.3e-006 3.3e-006
                  Closure: 2.0e-006 2.0e-006 2.0e-006
          Code Generation: 2.2e-006 2.5e-006 2.2e-006
    Functional, Recursive: 2.5e-006 2.4e-006 2.2e-006
done

对于元组大小为 20,我们得到

name, t, t, t (times in sec)
                (control): 8.3e-007 9.1e-007 9.1e-007
                      CPS: 1.3e-006 1.3e-006 1.1e-006
                Coroutine: 2.7e-006 2.7e-006 2.7e-006
         {...} and unpack: 3.0e-006 3.2e-006 3.0e-006
  {...} and unpack with n: 3.7e-006 3.3e-006 3.7e-006
          nil Placeholder: 1.0e-005 1.0e-005 1.0e-005
                  Closure: 1.8e-005 1.8e-005 1.8e-005
          Code Generation: 1.9e-005 1.8e-005 1.9e-005
    Functional, Recursive: 5.7e-005 5.7e-005 5.8e-005
done

请注意,表构造方法的时间相对于元组大小变化相对较小(由于构造表的初始开销)。相反,闭包的使用会导致运行时间随元组大小变化更显著。

问题 #2: 合并列表

问题:给定两个可变长度列表(例如,两个函数 `f` 和 `g` 的返回值,它们都返回多个值),将它们组合成一个列表。

这可能是一个问题,因为 Lua 的行为是丢弃函数的所有返回值,除了第一个返回值,除非它是列表中的最后一个项目。

local function f() return 1,2,3 end
local function g() return 4,5,6 end
print(f(), g()) -- prints 1 4 5 6

除了将列表转换为表等对象(通过上面问题 #1 中的方法)的明显解决方案之外,还有一些方法可以使用函数调用来实现这一点。

解决方案

以下方法通过一次只预先添加一个元素并延迟评估其中一个列表来递归地组合列表。

local function helper(f, n, a, ...)
  if n == 0 then return f() end
  return a, helper(f, n-1, ...)
end
local function combine(f, ...)
  local n = select('#', ...)
  return helper(f, n, ...)
end

-- TEST
local function join(...)
  local t = {n=select('#', ...), ...}
  for i=1,t.n do t[i] = tostring(t[i]) end
  return table.concat(t, ",")
end
local function f0() return end
local function f1() return 1 end
local function g1() return 2 end
local function f3() return 1,2,3 end
local function g3() return 4,5,6 end
assert(join(combine(f0, f0())) == "")
assert(join(combine(f0, f1())) == "1")
assert(join(combine(f1, f0())) == "1")
assert(join(combine(g1, f1())) == "1,2")
assert(join(combine(g3, f3())) == "1,2,3,4,5,6")
print "done"

问题 #3: 选择列表中的前 N 个元素

问题:返回一个列表,该列表包含另一个列表中的前 N 个元素。

`select` 函数允许选择列表中的最后 N 个元素,但没有内置函数用于选择前 N 个元素。

解决方案

local function helper(n, a, ...)
  if n == 0 then return end
  return a, helper(n-1, ...)
end
local function first(k, ...)
  local n = select('#', ...)
  return helper(k, ...)
end

-- TEST
local function join(...)
  local t = {n=select('#', ...), ...}
  for i=1,t.n do t[i] = tostring(t[i]) end
  return table.concat(t, ",")
end
local function f0() return end
local function f1() return 1 end
local function f8() return 1,2,3,4,5,6,7,8 end
assert(join(first(0, f0())) == "")
assert(join(first(0, f1())) == "")
assert(join(first(1, f1())) == "1")
assert(join(first(0, f8())) == "")
assert(join(first(1, f8())) == "1")
assert(join(first(2, f8())) == "1,2")
assert(join(first(8, f8())) == "1,2,3,4,5,6,7,8")
print "done"

注意:如果元素数量固定,则解决方案更简单。

local function firstthree(a,b,c) return a,b,c end
assert(join(firstthree(f8())) == "1,2,3")  -- TEST

代码生成方法可以基于此。

问题 #4: 将一个元素追加到列表

问题:将一个元素追加到列表中。

请注意,将一个元素预先添加到列表中很简单:`{a, ...}`

解决方案

local function helper(a, n, b, ...)
  if   n == 0 then return a
  else             return b, helper(a, n-1, ...) end
end
local function append(a, ...)
  return helper(a, select('#', ...), ...)
end

注意:如果元素数量固定,则解决方案更简单。

local function append3(e, a, b, c) return a, b, c, e end

问题 #5: 反转列表

问题:反转列表。

解决方案

local function helper(n, a, ...)
  if n > 0 then return append(a, helper(n-1, ...)) end
end
local function reverse(...)
  return helper(select('#', ...), ...)
end

注意:如果元素数量固定,则解决方案更简单。

local function reverse3(a,b,c) return c,b,a end

问题 #6: map 函数

问题:实现列表上的 map [3] 函数。

解决方案

local function helper(f, n, a, ...)
  if n > 0 then return f(a), helper(f, n-1, ...) end
end
local function map(f, ...)
  return helper(f, select('#', ...), ...)
end

问题 #7: filter 函数

问题:实现列表上的 filter [4] 函数。

解决方案

local function helper(f, n, a, ...)
  if n > 0 then
    if f(a) then return a, helper(f, n-1, ...)
    else         return    helper(f, n-1, ...) end
  end
end
local function grep(f, ...)
  return helper(f, select('#', ...), ...)
end

问题 #8: 遍历 Varargs

问题:遍历可变参数中的所有元素。

解决方案

for n=1,select('#',...) do
  local e = select(n,...)
end

如果您不需要 nil 元素,也可以使用以下方法

for _, e in ipairs({...}) do
   -- something with e
end

如果您希望使用迭代器函数而不每次都创建垃圾表,可以使用以下方法

do
  local i, t, l = 0, {}
  local function iter(...)
    i = i + 1
    if i > l then return end
    return i, t[i]
  end

  function vararg(...)
    i = 0
    l = select("#", ...)
    for n = 1, l do
      t[n] = select(n, ...)
    end
    for n = l+1, #t do
      t[n] = nil
    end
    return iter
  end
end

for i, v in vararg(1, "a", false, nil) do print(i, v) end -- test
-- Output:
--   1	1
--   2	"a"
--   3	false
--   4	nil

其他评论

(无)


--DavidManura,2007 年,Lua 5.1

另请参阅


RecentChanges · preferences
edit · history
最后编辑于 2017 年 4 月 1 日下午 7:50 GMT (diff)