方法链包装器

lua-users home
wiki

有时,我们希望为内置类型(如字符串和函数)添加自定义方法,尤其是在使用方法链 [1][2]

("  test  "):trim():repeatchars(2):upper() --> "TTEESSTT"

(function(x,y) return x*y end):curry(2) --> (function(y) return 2*y end)

我们可以使用 debug 库(debug.setmetatable)来实现这一点 [5]。缺点是,每个内置类型都有一个通用的元表。修改此元表会导致全局副作用,这可能导致程序中独立维护的模块之间发生冲突。出于良好原因,通常不鼓励在常规代码中使用 debug 库中的函数。许多人会避免注入这些全局元表,而另一些人则觉得它们过于方便而无法避免 [3][6][ExtensionProposal]。甚至有人询问为什么内置类型的对象没有自己的元表 [7]

...
debug.setmetatable("", string_mt)
debug.setmetatable(function()end, function_mt)

我们可以改为仅使用独立函数

(repeatchars(trim("test"), 2)):upper()

curry(function(x,y) return x*y end, 2)

这是最简单的解决方案。简单的解决方案通常是好的。尽管如此,将某些操作视为方法调用,而另一些操作视为独立的全局函数,以及由此产生的重排,可能会产生一定的差异。

一种避免修改全局元表的解决方案是将对象包装在自己的类中,在包装器上执行方法调用链操作,然后解开对象。

示例看起来像这样

S"  test  ":trim():repeatchars(2):upper()() --> TTEESSTT

S"  TEST  ":trim():lower():find('e')() --> 2 2

S 函数将给定对象包装到包装器对象中。包装器对象上的方法调用链会就地操作被包装的对象。最后,包装器对象通过函数调用 () 进行解包。

对于返回单个值的函数,另一种解包方式是使用一元减号

-S"  test  ":trim():repeatchars(2):upper() --> TTEESSTT

为了根据字符串函数表 stringx 来定义 S,我们可以使用此代码

local stringx = {}
for k,v in pairs(string) do stringx[k] = v end
function stringx.trim(self)
  return self:match('^%s*(%S*)%s*$')
end
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    for i=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end

local S = buildchainwrapbuilder(stringx)

buildchainwrapbuilder 函数是通用的,它实现了我们的设计模式

-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).
-- version 20090430
local select = select
local setmetatable = setmetatable
local unpack = unpack
local rawget = rawget

-- https://lua-users.lua.ac.cn/wiki/CodeGeneration
local function memoize(func)
  return setmetatable({}, {
    __index = function(self, k) local v = func(k); self[k] = v; return v end,
    __call = function(self, k) return self[k] end
  })
end

-- unique IDs (avoid name clashes with wrapped object)
local N = {}
local VALS = memoize(function() return {} end)
local VAL = VALS[1]
local PREV = {}

local function mypack(ow, ...)
  local n = select('#', ...)
  for i=1,n do ow[VALS[i]] = select(i, ...) end
  for i=n+1,ow[N] do ow[VALS[i]] = nil end
  ow[N] = n
end

local function myunpack(ow, i)
  i = i or 1
  if i <= ow[N] then
    return rawget(ow, VALS[i]), myunpack(ow, i+1)
  end
end

local function buildchainwrapbuilder(t)
  local mt = {}
  function mt:__index(k)
    local val = rawget(self, VAL)
    self[PREV] = val -- store in case of method call
    mypack(self, t[k])
    return self
  end
  function mt:__call(...)
    if (...) == self then -- method call
      local val = rawget(self, VAL)
      local prev = rawget(self, PREV)
      self[PREV] = nil
      mypack(self, val(prev, select(2,...)))
      return self
    else
      return myunpack(self, 1, self[N])
    end
  end
  function mt:__unm() return rawget(self, VAL) end

  local function build(o)
    return setmetatable({[VAL]=o,[N]=1}, mt)
  end
  return build
end

local function chainwrap(o, t)
  return buildchainwrapbuilder(t)(o)
end

测试套件

-- simple examples
assert(-S"AA":lower() == "aa")
assert(-S"AB":lower():reverse() == "ba")
assert(-S"  test  ":trim():repeatchars(2):upper() == "TTEESSTT")
assert(S"  test  ":trim():repeatchars(2):upper()() == "TTEESSTT")

-- basics
assert(S""() == "")
assert(S"a"() == "a")
assert(-S"a" == "a")
assert(S(nil)() == nil)
assert(S"a":byte()() == 97)
local a,b,c = S"TEST":lower():find('e')()
assert(a==2 and b==2 and c==nil)
assert(-S"TEST":lower():find('e') == 2)

-- potentially tricky cases
assert(S"".__index() == nil)
assert(S"".__call() == nil)
assert(S""[1]() == nil)
stringx[1] = 'c'
assert(S"a"[1]() == 'c')
assert(S"a"[1]:upper()() == 'C')
stringx[1] = 'd'
assert(S"a"[1]() == 'd') -- uncached
assert(S"a".lower() == string.lower)

-- improve error messages?
--assert(S(nil):z() == nil)

print 'DONE'

上述实现具有以下特点和假设

我们也可以用其他方式来表示链式调用

S{"  test  ", "trim", {"repeatchars",2}, "upper"}

S("  test  ", "trim | repeatchars(2) | upper")

但这样看起来不太常规。(注意:最后一行中的第二个参数是无点(point-free)的 [4]。)

方法链包装器第二版 - 链末的对象

我们可以改为像这样表示调用链

chain(stringx):trim():repeatchars(5):upper()('  test   ')

其中操作的对象位于最末端。这减少了忘记解包的可能性,并且允许分离和重用

f = chain(stringx):trim():repeatchars(5):upper()
print ( f('  test  ') )
print ( f('  again  ') )

实现这一点有多种方法(函数式、代码生成和虚拟机)。这里我们采用后一种方法。

-- method call chaining, take #2
-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).
-- version 20090501

-- unique IDs to avoid name conflict
local OPS = {}
local INDEX = {}
local METHOD = {}

-- table insert, allowing trailing nils
local function myinsert(t, v)
  local n = t.n + 1; t.n = n
  t[n] = v
end

local function eval(ops, x)
  --print('DEBUG:', unpack(ops,1,ops.n))
  local t = ops.t

  local self = x
  local prev
  local n = ops.n
  local i=1; while i <= n do
    if ops[i] == INDEX then
      local k = ops[i+1]
      prev = x  -- save in case of method call
      x = t[k]
      i = i + 2
    elseif ops[i] == METHOD then
      local narg = ops[i+1]
      x = x(prev, unpack(ops, i+2, i+1+narg))
      i = i + 2 + narg
    else
      assert(false)
    end
  end
  return x
end

local mt = {}
function mt:__index(k)
  local ops = self[OPS]
  myinsert(ops, INDEX)
  myinsert(ops, k)
  return self
end

function mt:__call(x, ...)
  local ops = self[OPS]
  if x == self then -- method call
    myinsert(ops, METHOD)
    local n = select('#', ...)
    myinsert(ops, n)
    for i=1,n do
      myinsert(ops, (select(i, ...)))
    end
    return self
  else
    return eval(ops, x)
  end
end

local function chain(t)
  return setmetatable({[OPS]={n=0,t=t}}, mt)
end

粗略的测试代码

local stringx = {}
for k,v in pairs(string) do stringx[k] = v end
function stringx.trim(self)
  return self:match('^%s*(%S*)%s*$')
end
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    for i=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end

local C = chain
assert(C(stringx):trim():repeatchars(2):upper()("  test  ") == 'TTEESSTT')
local f = C(stringx):trim():repeatchars(2):upper()
assert(f"  test  " == 'TTEESSTT')
assert(f"  again  " == 'AAGGAAIINN')
print 'DONE'

方法链包装器第三版 - 带有作用域感知元表的词法注入

另一种想法是修改字符串元表,使字符串方法的扩展仅在词法作用域内可见。下面的代码不完美(例如,嵌套函数),但它是一个开始。示例

-- test example libraries
local stringx = {}
function stringx.trim(self)  return self:match('^%s*(%S*)%s*$') end
local stringxx = {}
function stringxx.trim(self) return self:match('^%s?(.-)%s?$') end

-- test example
function test2(s)
  assert(s.trim == nil)
  scoped_string_methods(stringxx)
  assert(s:trim() == ' 123 ')
end
function test(s)
  scoped_string_methods(stringx)
  assert(s:trim() == '123')
  test2(s)
  assert(s:trim() == '123')
end
local s = '  123  '
assert(s.trim == nil)
test(s)
assert(s.trim == nil)
print 'DONE'

scoped_string_methods 函数将给定的函数表分配给当前正在执行的函数的范围。范围内的所有字符串索引都通过该给定表进行。

以上使用了此框架代码

-- framework
local mt = debug.getmetatable('')
local scope = {}
function mt.__index(s, k)
  local f = debug.getinfo(2, 'f').func
  return scope[f] and scope[f][k] or string[k]
end
local function scoped_string_methods(t)
  local f = debug.getinfo(2, 'f').func
  scope[f] = t
end

方法链包装器第四版 - 使用 MetaLua 的词法注入

我们可以通过 MetaLua 更稳健地实现上述类似的功能。示例如下。

-{extension "lexicalindex"}

-- test example libraries
local stringx = {}
function stringx.trim(self)  return self:match('^%s*(%S*)%s*$') end

local function f(o,k)
  if type(o) == 'string' then
    return stringx[k] or string[k]
  end
  return o[k]
end

local function test(s)
  assert(s.trim == nil)
  lexicalindex f
  assert(s.trim ~= nil)
  assert(s:trim():upper() == 'TEST')
end
local s = '  test  '
assert(s.trim == nil)
test(s)
assert(s.trim == nil)

print 'DONE'

语法扩展引入了一个新的关键字 lexicalindex,它指定了一个函数,每当在当前范围内对值进行索引时都会调用该函数。

以下是相应的纯 Lua 源代码

--- $ ./build/bin/metalua -S vs.lua
--- Source From "@vs.lua": ---
local function __li_invoke (__li_index, o, name, ...)
   return __li_index (o, name) (o, ...)
end

local stringx = { }

function stringx:trim ()
   return self:match "^%s*(%S*)%s*$"
end

local function f (o, k)
   if type (o) == "string" then
      return stringx[k] or string[k]
   end
   return o[k]
end

local function test (s)
   assert (s.trim == nil)
   local __li_index = f
   assert (__li_index (s, "trim") ~= nil)
   assert (__li_invoke (__li_index, __li_invoke (__li_index, s, "trim"), "upper"
) == "TEST")
end

local s = "  test  "

assert (s.trim == nil)

test (s)

assert (s.trim == nil)

print "DONE"

lexicalindex Metalua 扩展的实现如下

-- lexical index in scope iff depth > 0
local depth = 0

-- transform indexing expressions
mlp.expr.transformers:add(function(ast)
  if depth > 0 then
    if ast.tag == 'Index' then
      return +{__li_index(-{ast[1]}, -{ast[2]})}
    elseif ast.tag == 'Invoke' then
      return `Call{`Id'__li_invoke', `Id'__li_index', unpack(ast)}
    end
  end
end)

-- monitor scoping depth
mlp.block.transformers:add(function(ast)
  for _,ast2 in ipairs(ast) do
    if ast2.is_lexicalindex then
      depth = depth - 1; break
    end
  end
end)

-- handle new "lexicalindex" statement
mlp.lexer:add'lexicalindex'
mlp.stat:add{'lexicalindex', mlp.expr, builder=function(x)
  local e = unpack(x)
  local ast_out = +{stat: local __li_index = -{e}}
  ast_out.is_lexicalindex = true
  depth = depth + 1
  return ast_out
end}

-- utility function
-- (note: o must be indexed exactly once to preserve behavior
return +{block:
  local function __li_invoke(__li_index, o, name, ...)
    return __li_index(o, name)(o, ...)
  end
}

--DavidManura

另请参阅


RecentChanges · preferences
编辑 · 历史
最后编辑于 2009 年 12 月 8 日下午 7:38 GMT (差异)