符号微分

lua-users home
wiki

这最初是为 PenlightLibraries 做的一个小练习,但后来变得足够让人着迷,值得更彻底的实现。

符号代数的第一步是定义一个表示。将表达式转换为合适的形式实际上很简单;不需要解析表达式,因为我们有 Lua 为我们做这件事。使用 pl.func 库可以完成所有繁重的工作;它重新定义了算术运算,使其作用于占位符表达式 (PE),即包含名为占位符的虚拟变量的 Lua 表达式。pl.func 为名为 _1_2 等的参数定义了标准占位符,但 Var 函数将创建我们选择的新的占位符

utils.import 'pl.func'
a,b,c,d = Var 'a,b,c,d'
print(a+b+c+d)

这将确实以可读的形式打印出表达式。PE 运算符表达式存储为类似 {op='+',x,y} 的表的组合,这些表有一个关联的元表,该元表定义了元方法,如 __add 等。作为一棵树,使用 Lua 运算符的通常结合性,我们得到

绘制这些图很烦人,因此更好的表示法是 Lisp 风格的 S 表达式

1: (+ (+ (+ a b) c) d)

但是,通过我们将执行的各种操作,这种规范形式并不是 a+b+c+d 的唯一可能表示。

2: (+ a (+ b (+ c d)))
3: (+ (+ a b) (+ c d))

现在,经验表明这会导致疯狂。相反,更容易采用规范的 Lisp 表示

4: (+ a b c d)

一旦到位,许多操作变得很简单,例如与 (+ a c b d) 比较,只需对参数进行“无序比较”即可。以这种形式显示 PE 很简单。isPE 只需检查表达式以查看它是否为占位符表达式,方法是查看元表。具有 op=='X' 的 PE 是占位符变量,因此其余部分必须是表达式节点。

function sexpr (e)
	if isPE(e) then
  	  if e.op ~= 'X' then
	    local args = tablex.imap(sexpr,e)
	    return '('..e.op..' '..table.concat(args,' ')..')'
	  else
	    return e.repr
	  end
	else
	  return tostring(e)
	end
end

第一个任务是平衡表达式,它将表示 1-3 转换为 4。

function balance (e)
	if isPE(e) and e.op ~= 'X' then
	  local op,args = e.op
	  if op == '+' or op == '*' then
		args = rcollect(e)
	  else
		args = imap(balance,e)
	  end
	  for i = 1,#args do
		e[i] = args[i]
	  end
	end
	return e
end

对于非交换运算符,想法只是通过在 PE 的数组部分(即参数列表)上映射 balance 来平衡所有子表达式。然后将它们原样复制回来。非平凡的部分是处理 + 和 *,其中有必要从类似 1、2 或 3 的表达式树中收集所有参数,并将它们转换为第四种形式。

function tcollect (op,e,ls)
    if isPE(e) and e.op == op then
	for i = 1,#e do
	   tcollect(op,e[i],ls)
        end
    else
        ls:append(e)
        return
    end
end

function rcollect (e)
    local res = List()
    tcollect(e.op,e,res)
    return res
end

这递归地向下遍历相同运算符链(前面提到的 (+ (+ ...))并收集参数,将它们展平成 n 元 + 或 * 表达式。

这是一个有用的函数,它遵循相同的递归模式。

-- does this PE contain a reference to x?
function references (e,x)
    if isPE(e) then
		if e.op == 'X' then return x.repr == e.repr
		else
			return find_if(e,references,x)
		end
	else
		return false
	end
end

以下是创建 n 元积和和的函数。

function muli (args) return PE{op='*',unpack(args)} end
function addi (args) return PE{op='+',unpack(args)} end

有了这些,基本的微分规则就不难了。首先,只考虑包含变量的子表达式。

function diff (e,x)
    if isPE(e) and references(e,x) then
		local op = e.op
        if op == 'X' then
            return 1
	else
   	    local a,b = e[1],e[2]
            if op == '+' then -- differentiation is linear
		local args = imap(diff,e,x)
		return balance(addi(args))
            elseif op == '*' then -- product rule
		local res,d,ee = {}
		for i = 1,#e do
			d = fold(diff(e[i],x))
			if d ~= 0 then
	 		  ee = {unpack(e)} -- make a copy
			  ee[i] = d
			  append(res,balance(muli(ee)))
			end
		end
		if #res > 1 then return addi(res)
		else return res[1] end
            elseif op == '^' and isnumber(b) then -- power rule
                return b*x^(b-1)
            end
        end
	else
		return 0
	end
end

表达式之和的导数是导数之和。同样,imap 在子表达式上递归地应用函数。在构建结果后,我们重新平衡以求好运。

这里给出了乘积规则的一般形式,并明确检查了结果为零的项 - 这是 fold 的工作,我们将在下一节讨论。

(uvw..)' = u'vw.. + uv'w... + uvw'... + ...

最后,是简单的幂规则。请注意,结果可以用直接的方式表示,因为所有这些运算符都作用于 PE。

事实上,如果你使用形式 1,二元 + 和 *,所有这些规则都更加清晰!但随后简化变得难以忍受。而简化(“折叠”)是最难做对的。fold 是一个很长的函数,所以我将分段处理它。

local op = e.op
local addmul = op == '*' or op == '+'
-- first fold all arguments
local args = imap(fold,e)
if not addmul and not find_if(args,isPE) then
  -- no placeholders in these args, we can fold the expression.
  local opfn = optable[op]
  if opfn then
    return opfn(unpack(args))
  else
   return '?'
  end
elseif addmul then

第一个 if 正在寻找子表达式没有符号的情况,即它类似于 2*510^2;在这种情况下,常数可以完全折叠。optable(在 pl.operator 中定义)给出了运算符名称和实现它们的实际函数之间的映射。

elseif op == '^' then
  if args[2] == 1 then return args[1] end -- identity
  if args[2] == 0 then return 1 end
end
return PE{op=op,unpack(args)}

此子句正在清除表达式,例如 x^1y^0,这些表达式自然地从 diff 中的幂规则产生。一旦 args 被处理,表达式就可以重新组合在一起。

此例程的大部分内容处理令人头疼的双胞胎,+ 和 *。

-- split the args into two classes, PE args and non-PE args.
local classes = List.partition(args,isPE)
local pe,npe = classes[true],classes[false]

List.partition 接受一个列表和一个函数,该函数接受一个参数并返回一个单一值。结果是一个表,其中键是返回的值,而值是函数返回该值的那些元素的列表。所以

List{1,2,3,4}:partition(function(x) return x > 2 end)
--> {false={1,2},true={3,4}}
List{'one',math.sin,10,20,{1,2}}:partition(type)
--> {function={function: 00369110},string={one},number={10,20},table={{{1,2}} }

(在数学上,这些被称为等价类,而 partition 将被称为商集

在这种情况下,我们希望将非符号参数与符号参数分开;顺序无关紧要。非符号参数 npe 可以折叠成一个常数。此时,运算符恒等式规则可以生效,因此我们可以删除 (* 0 x) 并将 (+ 0 x) 简化为 x

最终的简化是替换重复的值,以便 (* x x) 应该变为 (^ x 2),而 (+ x x x) 应该变为 (* x 3)pl.tablex 中的 count_map 将完成这项工作。它接受一个类似列表的表和一个定义等价性的函数,并返回一个从值到其出现次数的映射,以便 count_map{'a','b','a'} {a=2,b=1}

给定此测试函数

function testdiff (e)
  balance(e)
  e = diff(e,x)
  balance(e)
  print('+ ',e)
  e = fold(e)
  print('- ',e)
end

以及这些情况

testdiff(x^2+1)
testdiff(3*x^2)
testdiff(x^2 + 2*x^3)
testdiff(x^2 + 2*a*x^3 + x^4)
testdiff(2*a*x^3)
testdiff(x*x*x)

我们得到了这些结果,说明了为什么像fold这样的函数对于处理diff的结果是如此必要。

+ 	2 * x ^ 1 + 0
- 	2 * x
+ 	3 * 2 * x
- 	6 * x
+ 	2 * x ^ 1 + 2 * 3 * x ^ 2
- 	2 * x + 6 * x ^ 2
+ 	2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3
- 	6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x
+ 	2 * a * 3 * x ^ 2
- 	6 * a * x ^ 2
+ 	1 * x * x + x * 1 * x + x * x * 1
- 	x ^ 2 * 3

https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua

https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua

SteveDonovan

参见


最近更改 · 偏好设置
编辑 · 历史记录
最后编辑于 2012 年 7 月 4 日上午 10:41 GMT (差异)