符号微分 |
|
符号代数的第一步是定义一种表示方法。将表达式转换为合适的格式实际上很简单;不需要解析表达式,因为Lua可以为我们完成这项工作。使用pl.func库可以完成所有繁重的工作;它重新定义了算术运算,使其作用于*占位符表达式*(PE),这些表达式是包含称为占位符的虚拟变量的Lua表达式。pl.func定义了名为_1、_2等的标准占位符参数,但Var函数将创建我们选择的新占位符。
utils.import 'pl.func' a,b,c,d = Var 'a,b,c,d' print(a+b+c+d)
这将以可读的形式打印出表达式。PE运算符表达式存储为类似于 {op='+',x,y} 的表的组合,这些表具有关联的元表,定义了诸如__add等元方法。作为一个树,具有Lua运算符的通常的结合律,我们得到
绘制这些图表很麻烦,因此更好的表示法是Lisp风格的S表达式。
1: (+ (+ (+ a b) c) d)
然而,通过我们将执行的各种操作,这种规范形式不是a+b+c+d的唯一可能表示。
2: (+ a (+ b (+ c d))) 3: (+ (+ a b) (+ c d))
现在,经验表明这会导致疯狂。相反,最好采用规范的Lisp表示法。
4: (+ a b c d)
一旦实现了这一点,许多操作就变得很简单,例如与(+ a c b d)进行比较,只需要对参数进行“无序比较”。以这种形式显示PE很简单。isPE通过查看元表来简单地检查表达式是否为占位符表达式。操作符为'X'的PE是占位符变量,其余的必须是表达式节点。
function sexpr (e) if isPE(e) then if e.op ~= 'X' then local args = tablex.imap(sexpr,e) return '('..e.op..' '..table.concat(args,' ')..')' else return e.repr end else return tostring(e) end end
第一个任务是*平衡*表达式,它将表示1-3转换为4。
function balance (e) if isPE(e) and e.op ~= 'X' then local op,args = e.op if op == '+' or op == '*' then args = rcollect(e) else args = imap(balance,e) end for i = 1,#args do e[i] = args[i] end end return e end
对于非交换运算符,思想就是通过将balance映射到PE的数组部分(即参数列表)来平衡所有子表达式。然后将它们复制回原处。棘手的部分是处理+和*,因为有必要收集所有来自表达式树(看起来像1、2或3)的参数,并将它们转换为第四种形式。
function tcollect (op,e,ls) if isPE(e) and e.op == op then for i = 1,#e do tcollect(op,e[i],ls) end else ls:append(e) return end end function rcollect (e) local res = List() tcollect(e.op,e,res) return res end
这会递归地遍历相同的运算符链(前面提到的(+ (+ ...)),并收集参数,将它们展平成n元+或*表达式。
这是一个有用的函数,它遵循相同的递归模式。
-- does this PE contain a reference to x? function references (e,x) if isPE(e) then if e.op == 'X' then return x.repr == e.repr else return find_if(e,references,x) end else return false end end
这里是创建n元乘积和和的函数。
function muli (args) return PE{op='*',unpack(args)} end function addi (args) return PE{op='+',unpack(args)} end
有了这个,基本的微分规则就不难了。首先,只考虑包含该变量的子表达式。
function diff (e,x) if isPE(e) and references(e,x) then local op = e.op if op == 'X' then return 1 else local a,b = e[1],e[2] if op == '+' then -- differentiation is linear local args = imap(diff,e,x) return balance(addi(args)) elseif op == '*' then -- product rule local res,d,ee = {} for i = 1,#e do d = fold(diff(e[i],x)) if d ~= 0 then ee = {unpack(e)} -- make a copy ee[i] = d append(res,balance(muli(ee))) end end if #res > 1 then return addi(res) else return res[1] end elseif op == '^' and isnumber(b) then -- power rule return b*x^(b-1) end end else return 0 end end
表达式的和的导数是导数的和。同样,imap负责将函数递归地应用于子表达式。构造结果后,我们为了方便再次进行平衡。
乘积规则在这里以其一般形式给出,并显式检查那些结果为零的项——这是fold的工作,稍后将讨论它。
(uvw..)' = u'vw.. + uv'w... + uvw'... + ...
最后是简单的幂法则。注意结果是如何以简单明了的方式表示的,因为所有这些运算符都在PE上操作。
事实上,如果您使用形式1,即二元+和*,那么所有这些规则肯定会更清晰!但然后简化将变得无法忍受。而简化(“折叠”)是棘手的,需要正确处理。fold是一个较长的函数,所以我会分段处理它。
local op = e.op local addmul = op == '*' or op == '+' -- first fold all arguments local args = imap(fold,e) if not addmul and not find_if(args,isPE) then -- no placeholders in these args, we can fold the expression. local opfn = optable[op] if opfn then return opfn(unpack(args)) else return '?' end elseif addmul then
第一个if正在寻找一个子表达式没有符号的情况,即它像2*5或10^2;在这种情况下,常数可以完全折叠。optable(在pl.operator中定义)提供了运算符名称与实现它们的实际函数之间的映射。
elseif op == '^' then if args[2] == 1 then return args[1] end -- identity if args[2] == 0 then return 1 end end return PE{op=op,unpack(args)}
这个子句正在清理像x^1和y^0这样的表达式,这些表达式自然地来自diff中的幂法则。一旦args被处理完毕,就可以重新组合表达式。
这个例程的大部分内容处理棘手的配对,+和*。
-- split the args into two classes, PE args and non-PE args. local classes = List.partition(args,isPE) local pe,npe = classes[true],classes[false]
List.partition接受一个列表和一个函数,该函数接受一个参数并返回一个单一值。结果是一个表,其中键是返回的值,值是函数返回该值的元素的列表。所以
List{1,2,3,4}:partition(function(x) return x > 2 end)
--> {false={1,2},true={3,4}}
List{'one',math.sin,10,20,{1,2}}:partition(type)
--> {function={function: 00369110},string={one},number={10,20},table={{{1,2}} }
(在数学上,这些被称为*等价类*,而partition被称为*商集*)
在这种情况下,我们想将非符号参数与符号参数分开;顺序无关紧要。非符号参数npe可以折叠成一个常数。此时,运算符身份规则可以生效,这样我们就可以删除(* 0 x)并将(+ 0 x)简化为x。
最后的简化是替换重复的值,这样(* x x)应该变成(^ x 2),而(+ x x x)应该变成(* x 3)。pl.tablex中的count_map可以完成这项工作。它接受一个类似列表的表和一个定义等价性的函数,并返回一个从值到其出现次数的映射,因此 count_map{'a','b','a'} 是 {a=2,b=1} 。
给定这个测试函数
function testdiff (e) balance(e) e = diff(e,x) balance(e) print('+ ',e) e = fold(e) print('- ',e) end
以及这些情况
testdiff(x^2+1) testdiff(3*x^2) testdiff(x^2 + 2*x^3) testdiff(x^2 + 2*a*x^3 + x^4) testdiff(2*a*x^3) testdiff(x*x*x)
我们得到这些结果,这表明为什么像fold这样的东西对于处理diff的结果如此必要。
+ 2 * x ^ 1 + 0 - 2 * x + 3 * 2 * x - 6 * x + 2 * x ^ 1 + 2 * 3 * x ^ 2 - 2 * x + 6 * x ^ 2 + 2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3 - 6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x + 2 * a * 3 * x ^ 2 - 6 * a * x ^ 2 + 1 * x * x + x * 1 * x + x * x * 1 - x ^ 2 * 3
https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua
https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua