通用输入算法

lua-users home
wiki

Lua 中的通用函数和算法

本文介绍了 func 库的功能,该库旨在使输入操作更直观。源代码可以在 Files:wiki_insecure/func.lua 中找到。

有两个特殊的输入迭代器,numbers()words(),它们的工作原理类似于非常有用的 io.lines() 迭代器。要打印标准输入中找到的所有单词

-- words.lua
require 'func'
for w in words() do
  print(w)
end
要在一些文本上测试它,您可以在 OS 命令提示符下输入
$ lua words.lua < test.txt
打印迭代器生成的数值是一个非常常见的操作,因此 func 提供了非常方便的函数 printall()。它将序列的所有成员写入标准输出。默认情况下,它每行输出 7 个项目,用空格隔开,但您可以选择更改这些设置。在这种情况下,我们希望每个值都在单独的行上
printall(words(),'\n')
numbers() 创建一个包含其输入中找到的所有数字的序列。例如,要对所有输入数字求和
require 'func'
local s 
for x in numbers() do
  s = s + x
end
print(s)
求和是分析数据时的一个常见操作,因此 func 定义了一个通用的 sum() 函数。它返回总和和字段数量,因此很容易计算平均值。
local s,n = sum(numbers())
print('average value =',s/n)
请注意,这些迭代器会查找适当的模式,因此它们不依赖于单词或数字是否用空格隔开。numbers() 将在文件中找到所有看起来像数字的东西,并安全地忽略任何其他内容。因此,它对于包含大量注释的数据或输出文件很有用。这些迭代器接受一个可选的额外参数,它可以是文件或字符串。例如,要打印作为命令行参数传递的文件中的总和和项目数量
f = io.open(arg[1])
print(sum(numbers(f)))
f:close()
将迭代器的输出收集为一个表很有用。由于它非常简单且具有指导意义,这里是一个简化的 copy() 定义
function copy(iter)
  local res = {}
  local k = 1
  for v in iter do
     res[k] = v
     k = k + 1
  end
  return res
end
下一个示例创建了一个包含字符串中找到的所有数字的数组。显然,这是一个微不足道的案例,但通常需要从字符串中提取数字,这可能很棘手。特别是,这确保了它们实际上被转换了 - 我不止一次被简单的事实所困扰,即 arr['1'] 和 arr[1] 并不相同!
t = copy(numbers '10 20 30') -- will be {10,20,30}
s = sum(list(t))             -- will be 60
请注意表序列适配器 list(),它允许将表用作序列。使用这些函数对数组进行操作很常见,因此如果您确实将表传递给它们,list() 将自动被假设。要以特定格式打印数字数组,可以使用类似 printall(t,' ',5,'%7.3f') 的内容来格式化它们。以下是系统命令 sort 的实现,它使用 printall() 函数输出序列的每个值。我不能简单地说 table.foreach(t,print),因为该操作会同时传递索引和值,因此我实际上也会得到行号!
t = copy(io.lines())
table.sort(t)
printall(t,'\n')   -- try table.foreach(t,print) and see!
使用 sort() 函数,这将变成一行代码
printall(sort(io.lines()),'\n')
可以使用slice()迭代序列的一部分。它接受一个迭代器、一个起始索引和一个项目计数。例如,这是一个简单的head命令版本;它显示输入的前十行。
printall(slice(io.lines(),1,10),'\n')
有时我们只想计算一个序列;例如,这是一个完整的脚本,用于计算文件中所有单词的数量
require 'func'
print(count(words()))
在这种形式下,count() 并不是那么有用。但它可以接受一个函数来选择要计数的项目。例如,这让我大致了解了 Lua 文件中共有多少个公共函数。(如果我不将匹配限制在开头,它也会拾取本地函数和匿名函数)
require 'func'
print(count(io.lines(),matching '^%s*function'))
其中matching() 是以下简单函数。它创建一个闭包(绑定到本地上下文的函数),该函数在序列中的每个项目上调用
function matching(s)
  local strfind = string.find
  return function(v)
    return strfind(v,s)
  end
end

当然,您可以在这些操作中使用任何序列。如果您加载了非常有用的 lfs(Lua 文件系统)库,那么t = copy_if(lfs.dir(path),matching '%.cpp$') 将用path 中所有扩展名为.cpp 的文件填充一个列表。

修改count() 输入的另一种有用方法是使用unique()

-- number of unique words in a file
print(count(unique(words())))
unique() 的实现方式与通常的方式不同,通常的方式需要先对序列进行排序。相反,它使用count_map() 创建一个映射,其中键是项目,值是计数。一旦我们有了keys(),剩下的就很容易了,它是list() 的互补兄弟
function unique(iter)
  local t = count_map(iter)
  return keys(t)
end
经典的“统计文件中单词出现的次数”是
table.foreach(count_map(words()),print)
在比较两个序列时,将它们join()在一起很有用。这将打印出两个文件之间的差异
for x,y in join(numbers(f1),numbers(f2)) do
  print(x-y)
end

Lua 中的 AWK 编程风格

在我发现 Lua 之前,AWK 是我最喜欢的用于操作文本文件的语言。(我甚至设法用“AWK 是 Excel 的命令行等效项”的口号说服了一些同事。)为了让您体验一下,这里有一个完整的 AWK 程序,用于打印文件的第一个和第三列,并使用第四列进行缩放 - 请注意,对所有行的循环是隐式的
{ print $1/$4, $3/$4 }
func 库为此目的提供了迭代器fields()。以下是等效的 Lua 代码
for x,y,z in fields{1,3,4} do
   print(x/z,y/z)
end
这是我目前最喜欢的单行代码。它计算第 7 列中大于 44000 的值有多少个,并且速度大约是等效的 AWK 程序(使用 MAWK 运行)的一半。考虑到 AWK 对其专门任务的优化程度,这还不错!
print(count(fields{7},greater_than(44000)))
{ if ($7 > 44000) k++ } END { print(k) }
fields() 可以使用任何输入分隔符。这从一个逗号分隔的文件中读取一组值 - 注意,传递 n 而不是字段 ID 列表等同于 {1,2,...n}。
for x,y in fields({1,2},',',f) do ...
for x,y in fields(2,',',f) do ...  --equivalent--

性能和表达效率

我认为很明显,使用这种泛型编程风格可以非常简洁地表达常见的操作,但人们在这一点上往往有两个保留意见,这当然是我在 C++ 中使用 STL 的经验。第一个反对意见是函数式风格效率更低。理论上当然是这样,但在实践中效率有多低?例如,以下是使用序列 random() 创建一个包含随机值的表的记录
> tt = copy(random(10000))
> = sum(tt)
5039.542771691  10000
这些操作在我的老旧笔记本电脑上几乎是瞬间完成的,我只有在 1e5 个项目时才开始注意到。对于 1e6 个项目,第一个操作需要 2.14 秒,而显式循环需要 2.08 秒!如果我小心地使用局部变量,这将下降到 1.92,因此最佳的显式版本是
local t = {}
local random = math.random
for i = 1,1e6 do
   t[i] = random()
end
这种情况表明,没有令人信服的速度优势来证明这样做是正确的。(我选择这个例子正是因为它不涉及文件 I/O,这往往会主导 words()numbers() 的运行时间。)优势在于,代码出错的可能性更小;泛型编程人员认为显式循环“乏味且容易出错”,正如 Stroustrup 所说。

第二个反对意见是,它会导致奇怪且不自然的代码。这在 C++ 中当然可能是这样,C++(让我们面对现实)并不真正适合函数式风格;没有闭包,严格的静态类型不断地妨碍,要求所有东西都是模板。这种风格更适合 Lua - 在 C++ 中这样做不会有一半的可读性,即使使用 Boost Lambda 库也是如此

-- sum of squares of input data using an internal iterator
for_each(numbers(),function(v)
    s = s + v*v
end)
-- sum of squares of input data using an external iterator
for v in numbers() do
    s = s + v*v
end
这个想法不是要替换所有循环,而是要替换常见的泛型模式。这样的代码变得更容易阅读,因为任何显式循环都会更加突出。Lua 特别适合这种风格,这种风格在 C++ 中往往显得强迫。

编写自定义输入对象

如果 f 不是字符串,那么 words(f) 将使用文件对象 f。事实上,f 可以是任何具有 read 方法的对象。代码所假设的只是 f:read() 将返回下一行输入文本。以下是一个更复杂的示例,我创建了一个类 Files,它允许我们从文件列表中读取。显而易见的应用是模仿 AWK 的行为,其中命令行上的每个文件都成为标准输入的一部分。
Files = {}
 
function Files.create(list)
   local files = {}
   files.list = {}
   local n = table.getn(list)
   for i = 1,n do
      files.list[i] = list[i]
   end
   files.open_next = Files.open_next
   files.read = Files.read
   files:open_next()
   return files
end
 
function Files:open_next()
   if self.f then self.f:close() end
   local nf = table.remove(self.list,1)
   if nf then
      self.f = io.open(nf)
      return true
   else
      self.f = nil
      return false
   end
end
 
function Files:read()
  local ret = self.f:read()
  if not ret then
     if not self:open_next() then return nil end
     return self.f:read()
  else
     return ret
  end
end
我需要解释一个明显的矛盾。在赞扬无循环编程的乐趣之后,Files.create() 中有一个经典的复制表格循环。Lua 程序被传递一个名为 arg 的全局表,该表包含命令行参数,arg[1]arg[2] 等。但也有 arg[0],即脚本名称,以及 arg[-1],即实际的程序名称。所讨论的显式循环是为了确保我们不会复制这些字段!
files = Files.create(arg)
printall(words(files))

实现和进一步开发的说明

func 的大部分内容都是对同一主题的直接变体;迭代器和函数作为闭包。PiL 的第 7.1 节 [ref?] 对这些问题进行了很好的解释,我使用 allwords 示例作为 words()numbers() 的基础。fields() 最初以一种简单的方式实现,依次获取每个字段,但后来通过创建自定义正则表达式实现为对 string.find() 的一次调用。例如,如果需要字段 1 和 3,用逗号分隔,则正则表达式如下 - 字段定义为任何不是逗号的东西,我们使用 () 捕获所需的字段。
'%s*([^,]+),[^,]+,([^,])'
序列的概念非常通用,这意味着可以轻松地将 func 操作与任何提供迭代器的库一起使用。这通常会极大地简化代码。例如,以下是如何使用 luasql 的方法。考虑访问查询结果所有行的规范方法
cur = con:execute 'SELECT * FROM [Event Summaries]'
mag = -9
row = cur:fetch({},'a')
while row do
  if row.Magnitude > mag then 
     mag = row.Magnitude
  end
  row = cur:fetch(row,'n')
end
cur:close()

我可以通过创建一个跟踪 row 的迭代器来简化此操作

function rows(cursor)
  local row = {}
  return function()
    return cursor:fetch(row,'a')
  end
end

for row in rows(cur) do
   if row.Magnitude > mag then 
      mag = row.Magnitude    
   end
end
这已经是更好的循环了,因为我们不必两次调用 cursor:fetch 并且不必处理本地 row。我们还可以实现等效于 fields 的内容
function column(fieldname,cursor)
  local row = {}
  return function()
    row = cur:fetch(row,'a')
    if not row then return nil 
    else return row[fieldname]
    end
  end
end

local minm,maxm = minmax(column('Magnitude',cur))
不再有任何显式循环!当然,使用 SQL WHERE 子句来约束序列通常更有效。以下方法有效,但不是完成这项工作的最佳方法
print(count(column('Magnitude',cur),greater_than(2)))

-- SteveDonovan


最近更改 · 偏好设置
编辑 · 历史记录
最后编辑于 2007 年 7 月 21 日下午 6:49 GMT (差异)