优化后的 Str Rep

lua-users home
wiki


[!] 版本说明: 以下代码适用于旧版本的 Lua,Lua 4。它在 Lua 5 中无法运行。

以下是一个 lua 版本的 strrep 函数,它比用 C 编写的原始版本更快。平均而言,它快了大约 3 倍。

所有测试都在一台奔腾 II 上进行,运行单用户模式下的 linux,有足够的内存来防止交换,并且使用 Lua 版本 4.0。

该算法的灵感来自 RobertoIerusalimschy 的 LTN 9,由 LuizCarlosSilveira 编写。

该算法的本质是尽可能减少连接次数。

仅仅出于好奇:与原始的 strrep 相比,lua strrepO() 函数存在疑问,但需要进一步研究。


绘制它们的“重复次数”x“时间”曲线,我们可以观察到这两个函数的行为
(时间以秒为单位,重复次数以字节为单位)


以下是一个包含该算法的程序。它被用于测试实现的正确性和生成用于绘制曲线的數據。

function log2(n)
    local _n = 2
    local x = 1
    if (_n < n) then
        repeat
            x = x + 1
            _n = _n + _n
        until (_n >= n)
    elseif (_n > n) then
        if (n == 1) then
            return 0
        else
            return nil
        end
    end 
    if (_n > n) then
        return x-1
    else
        return x
    end 
end 
    
function get_bits(n)
    local bits = {}
    local rest = n
    repeat
        local major_bit = log2(rest)
        rest = rest - 2^major_bit
        bits[major_bit] = 1
        if (bits.count == nil) then
            bits.count = major_bit
        end
    until (rest == 0)
    return bits
end



function fast_strrep(str, times)
    local bits = get_bits(times)
    local strs = {[0] = str}

    local count = bits.count

    for i = 1, count do
        strs[i] = strs[i-1] .. strs[i-1]
    end

    local result = ''
    for i = 0, count do
        if (bits[i]) then
            result = result .. strs[i]
        end
    end

    return result

end

for numreps = 1024, 30*1024*1024, 1024*64 do

    a = nil
    b = nil
    collectgarbage()

    start = clock()
    a = fast_strrep("a", numreps)
    print("L:"..numreps.." "..(clock() - start))
    start = clock()
    b = strrep("a", numreps)
    print("C:"..numreps.." "..(clock() - start))

    if (a~=b) then
        print("the algorithm is wrong!")
    else
        print("ok")
    end

    flush(_STDOUT)

end

        


我并不惊讶你的版本更快,实际上;lua 库版本的 strrep(至少 v 4.0)在函数调用中存在非常大的开销(每个字符一次),而连接函数没有这种开销。我发现奇怪的是,库版本没有简单地计算出它需要多少个字符,为它们分配内存,然后重复地使用 memcpy 复制源字符串,我认为这会快一个数量级,并且是 O(MN + M)(M 个长度为 N 的字符串的副本)。但我猜想,他们可能并没有真正考虑过连接兆字节大小的数据。

你使用的算法让人想起用于计算指数的最小乘法算法。这里有一个版本,它利用了 Lua 将 a .. b .. c 优化为单个操作的事实,并且还避免了创建临时向量。我认为你会发现它比你的算法快大约两倍(而且代码更短)。它也可能作为一个例子,说明如何在不造成太多痛苦的情况下进行类似位操作的操作。 -- RiciLake

  -- Suppose that x = b[n]*2^n + b[n-1]*2^(n-1) + ... + b[0]*2^0
  --   (where every b[i] is either 0 or 1)
  -- This is exactly equivalent to:
  --    b[0] + 2 * (b[1] + 2 * (b[2] + (... + b[n])))
  -- So we've effectively eliminated all the multiplications, replacing them with doubling.

  -- Now, x * y (for any y) can be calculated by distributing multiplication over the
  -- above, which effectively replaces every b[i] with b[i] * y. However, every b[i]
  -- is either 0 or 1, so the product is either 0 or y.

  -- Now, if k is an integer and str1 and str2 are strings, and we write:
  --   str1 + str2       for the concatenation of str1 and str2
  --   k * str1          for "k copies of str1 concatenated"
  -- we can see that we have + and * are "just like" integer arithmetic in the sense that
  -- + and * are commutative and associative, and * distributes over +. So the equivalence
  -- continues to work, except that every term must be either "" (for 0) or y (the string).

  -- All that is left is to compute the expression from the inside out: each step is
  -- either 2 * r or y + 2 * r, where r is the cumulated value and y is the original string.
  -- In string terms, we can write these as result .. result (2 * r) and
  -- result .. result .. str (2 * r + y)

  -- We could use the same idea to compute integer exponents in the minimum number of
  -- multiplications, using * and ^ instead of + and * (which is where this algorithm
  -- comes from.)

  -- This makes use of the fact that Lua optimises a .. b .. c into a single concatenation.
  -- With a bit more work, we could use any base we wanted to, not just base 2. But it would
  -- require more options in the if statement.

function another_strrep(str, times)
  local result = ""
  local high_bit = 1
  while high_bit < times do high_bit = high_bit * 2 end

  -- at this point, high_bit is the largest (integral) power of 2 smaller than times
  -- (unless times < 1 in which case high_bit is 1)
  -- The computation of highbit could be:
  --   local high_bit = 2 ^ floor(log(times) / log(2))
  -- which is probably faster but requires the math library

  -- we are now going to work through times, bit by bit, making use of the above formula:

  while high_bit >= 1 do
    if high_bit <= times then           -- the bit is 1 if times is >= high_bit
      times = times - high_bit          -- we "turn it off" for the next iteration
      result = result .. result .. str  -- and the next step is 2 * r + y
    else                                -- the bit is 0
      result = result .. result         -- so the next step is 2 * r
    end
    high_bit = high_bit / 2             -- Now go for the next bit
  end
  return result
end


你的算法真的很棒。谢谢。你们两个(你和我)的想法几乎是一样的,对吧?你做了一个我当时没有意识到如何做到的出色优化,那就是防止创建辅助片段。我绘制了这三种算法的曲线,你可以看到下面。对于这张图中绘制的数据,这三个函数的精确平均关系是
luiz/rici = 1.41  (rici is  1.41  times faster than luiz)
c/luiz    = 2.98  (luiz is  2.98  times faster than c)
c/rici    = 4.19  (rici is  4.19  times faster than c)
        

由于你的算法更快(并且似乎使用更少的内存),它才是应该出现在这个页面上的算法。除非你讨厌这个想法,否则我会删除我的算法,把位置留给你。但是,在此之前,我想问你一个问题:你能在你的代码中添加一些注释,以便更清楚地说明算法吗?由于你的优化,算法背后的想法变得模糊了... --LuizCarlosSilveira

好的,我强迫症地加了注释。我希望它很清楚;有时我觉得代码本身更清晰。这个函数没有达到最佳状态,因为它比必要的多进行了一次连接……我试图让代码更短,并依赖 Lua 非常快速地执行 "" .. "" .. str

为了好玩,以及为了演示一些东西(我不确定是什么),我加入了上述算法的十进制版本。与其使用 case 语句并进行逐位计算,我使用 gsub 来执行循环并将重复计数转换为字符串以计算出数字。表查找(可能)比一串 if 语句快得多,所以我使用了它。%state 是一个标准技巧,用于解决 Lua 4.0 没有真正的闭包的问题。

我不声称这个函数很容易阅读,但我的测试表明它更快。(抱歉,没有其他评论,但想法是一样的,所以你应该能够弄明白。:-))这只是表明如果你足够扭曲,你可以做到什么。我还有另一个类似的例子:一个我编写的 join 函数,它会延迟编译子例程来执行相同类型的指数分解问题;即使它必须组合和编译函数,它最终比朴素的 join 快得多。当然,这些函数必须被记忆才能利用这一点。我会尝试发布那个函数。 -- RiciLake

do
  local concats = {
    ["0"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a end,
    ["1"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b end,
    ["2"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b end,
    ["3"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b end,
    ["4"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b end,
    ["5"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b end,
    ["6"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b end,
    ["7"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b .. b end,
    ["8"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b .. b .. b end,
    ["9"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b .. b .. b .. b end,
  }

  function decimal_strrep(str, times)
    local state = {r = "" }
    local concats = %concats
    times = tostring(times)
    if strfind(times, "^[0-9]+$") then
      gsub(times, "(.)",
           function(digit)
             %state.r = %concats[digit](%state.r, %str)
           end)
    end
    return state.r
  end
end

        


顺便说一句,测试并不完美,因为时间会随着重复计数的二进制展开中 1 的数量而变化。(我认为我的版本对这一点稍微不那么敏感,但它仍然是一个因素。)你使用的测试计数在它们的二进制展开中大多是 0。 -- RiciLake


我相信这就是你为什么说你的算法应该比我的快两倍的原因。实际上,当我使用仅一位开启形成的重复次数时,这是内存消耗的最佳情况。在最坏的情况下,我的算法似乎使用了你算法两倍的内存。我同意我做的测量应该被审查,但这个页面才刚刚开始... --LuizCarlosSilveira


说得有道理。尝试用比“a”更长的字符串进行测试也很有趣。我一直没有弄清楚如何充分地对 Lua 程序进行基准测试,因为垃圾回收时间会随着总堆大小而变化。最好运行几次程序以使堆大小稳定下来,然后进行计时。 --RiciLake

最近更改 · 偏好设置
编辑 · 历史记录
最后编辑于 2017 年 9 月 21 日下午 9:29 GMT (差异)