符号微分 |
|
符号代数的第一步是定义一个表示。将表达式转换为合适的形式实际上很简单;不需要解析表达式,因为我们有 Lua 为我们做这件事。使用 pl.func
库可以完成所有繁重的工作;它重新定义了算术运算,使其作用于占位符表达式 (PE),即包含名为占位符的虚拟变量的 Lua 表达式。pl.func
为名为 _1
、_2
等的参数定义了标准占位符,但 Var
函数将创建我们选择的新的占位符
utils.import 'pl.func' a,b,c,d = Var 'a,b,c,d' print(a+b+c+d)
这将确实以可读的形式打印出表达式。PE 运算符表达式存储为类似 {op='+',x,y}
的表的组合,这些表有一个关联的元表,该元表定义了元方法,如 __add
等。作为一棵树,使用 Lua 运算符的通常结合性,我们得到
绘制这些图很烦人,因此更好的表示法是 Lisp 风格的 S 表达式
1: (+ (+ (+ a b) c) d)
但是,通过我们将执行的各种操作,这种规范形式并不是 a+b+c+d
的唯一可能表示。
2: (+ a (+ b (+ c d))) 3: (+ (+ a b) (+ c d))
现在,经验表明这会导致疯狂。相反,更容易采用规范的 Lisp 表示
4: (+ a b c d)
一旦到位,许多操作变得很简单,例如与 (+ a c b d)
比较,只需对参数进行“无序比较”即可。以这种形式显示 PE 很简单。isPE
只需检查表达式以查看它是否为占位符表达式,方法是查看元表。具有 op=='X'
的 PE 是占位符变量,因此其余部分必须是表达式节点。
function sexpr (e) if isPE(e) then if e.op ~= 'X' then local args = tablex.imap(sexpr,e) return '('..e.op..' '..table.concat(args,' ')..')' else return e.repr end else return tostring(e) end end
第一个任务是平衡表达式,它将表示 1-3 转换为 4。
function balance (e) if isPE(e) and e.op ~= 'X' then local op,args = e.op if op == '+' or op == '*' then args = rcollect(e) else args = imap(balance,e) end for i = 1,#args do e[i] = args[i] end end return e end
对于非交换运算符,想法只是通过在 PE 的数组部分(即参数列表)上映射 balance
来平衡所有子表达式。然后将它们原样复制回来。非平凡的部分是处理 + 和 *,其中有必要从类似 1、2 或 3 的表达式树中收集所有参数,并将它们转换为第四种形式。
function tcollect (op,e,ls) if isPE(e) and e.op == op then for i = 1,#e do tcollect(op,e[i],ls) end else ls:append(e) return end end function rcollect (e) local res = List() tcollect(e.op,e,res) return res end
这递归地向下遍历相同运算符链(前面提到的 (+ (+ ...)
)并收集参数,将它们展平成 n 元 + 或 * 表达式。
这是一个有用的函数,它遵循相同的递归模式。
-- does this PE contain a reference to x? function references (e,x) if isPE(e) then if e.op == 'X' then return x.repr == e.repr else return find_if(e,references,x) end else return false end end
以下是创建 n 元积和和的函数。
function muli (args) return PE{op='*',unpack(args)} end function addi (args) return PE{op='+',unpack(args)} end
有了这些,基本的微分规则就不难了。首先,只考虑包含变量的子表达式。
function diff (e,x) if isPE(e) and references(e,x) then local op = e.op if op == 'X' then return 1 else local a,b = e[1],e[2] if op == '+' then -- differentiation is linear local args = imap(diff,e,x) return balance(addi(args)) elseif op == '*' then -- product rule local res,d,ee = {} for i = 1,#e do d = fold(diff(e[i],x)) if d ~= 0 then ee = {unpack(e)} -- make a copy ee[i] = d append(res,balance(muli(ee))) end end if #res > 1 then return addi(res) else return res[1] end elseif op == '^' and isnumber(b) then -- power rule return b*x^(b-1) end end else return 0 end end
表达式之和的导数是导数之和。同样,imap
在子表达式上递归地应用函数。在构建结果后,我们重新平衡以求好运。
这里给出了乘积规则的一般形式,并明确检查了结果为零的项 - 这是 fold
的工作,我们将在下一节讨论。
(uvw..)' = u'vw.. + uv'w... + uvw'... + ...
最后,是简单的幂规则。请注意,结果可以用直接的方式表示,因为所有这些运算符都作用于 PE。
事实上,如果你使用形式 1,二元 + 和 *,所有这些规则都更加清晰!但随后简化变得难以忍受。而简化(“折叠”)是最难做对的。fold
是一个很长的函数,所以我将分段处理它。
local op = e.op local addmul = op == '*' or op == '+' -- first fold all arguments local args = imap(fold,e) if not addmul and not find_if(args,isPE) then -- no placeholders in these args, we can fold the expression. local opfn = optable[op] if opfn then return opfn(unpack(args)) else return '?' end elseif addmul then
第一个 if
正在寻找子表达式没有符号的情况,即它类似于 2*5
或 10^2
;在这种情况下,常数可以完全折叠。optable
(在 pl.operator
中定义)给出了运算符名称和实现它们的实际函数之间的映射。
elseif op == '^' then if args[2] == 1 then return args[1] end -- identity if args[2] == 0 then return 1 end end return PE{op=op,unpack(args)}
此子句正在清除表达式,例如 x^1
和 y^0
,这些表达式自然地从 diff
中的幂规则产生。一旦 args
被处理,表达式就可以重新组合在一起。
此例程的大部分内容处理令人头疼的双胞胎,+ 和 *。
-- split the args into two classes, PE args and non-PE args. local classes = List.partition(args,isPE) local pe,npe = classes[true],classes[false]
List.partition
接受一个列表和一个函数,该函数接受一个参数并返回一个单一值。结果是一个表,其中键是返回的值,而值是函数返回该值的那些元素的列表。所以
List{1,2,3,4}:partition(function(x) return x > 2 end) --> {false={1,2},true={3,4}} List{'one',math.sin,10,20,{1,2}}:partition(type) --> {function={function: 00369110},string={one},number={10,20},table={{{1,2}} }
(在数学上,这些被称为等价类,而 partition
将被称为商集)
在这种情况下,我们希望将非符号参数与符号参数分开;顺序无关紧要。非符号参数 npe
可以折叠成一个常数。此时,运算符恒等式规则可以生效,因此我们可以删除 (* 0 x)
并将 (+ 0 x)
简化为 x
。
最终的简化是替换重复的值,以便 (* x x)
应该变为 (^ x 2)
,而 (+ x x x)
应该变为 (* x 3)
。pl.tablex
中的 count_map
将完成这项工作。它接受一个类似列表的表和一个定义等价性的函数,并返回一个从值到其出现次数的映射,以便 count_map{'a','b','a'}
为 {a=2,b=1}
。
给定此测试函数
function testdiff (e) balance(e) e = diff(e,x) balance(e) print('+ ',e) e = fold(e) print('- ',e) end
以及这些情况
testdiff(x^2+1) testdiff(3*x^2) testdiff(x^2 + 2*x^3) testdiff(x^2 + 2*a*x^3 + x^4) testdiff(2*a*x^3) testdiff(x*x*x)
我们得到了这些结果,说明了为什么像fold
这样的函数对于处理diff
的结果是如此必要。
+ 2 * x ^ 1 + 0 - 2 * x + 3 * 2 * x - 6 * x + 2 * x ^ 1 + 2 * 3 * x ^ 2 - 2 * x + 6 * x ^ 2 + 2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3 - 6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x + 2 * a * 3 * x ^ 2 - 6 * a * x ^ 2 + 1 * x * x + x * 1 * x + x * x * 1 - x ^ 2 * 3
https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua
https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua