背景
在计算机求和的过程中,一个大数和小数的相加会因为浮点数的有限精度,而导致截断误差的出现。所以在构建计算网格的时候,都要极力避免这样情形的发生,将计算统一在相对较近的数量级上。所以,当需要对一系列的数值做加法时,一个好的技巧是将这些数由大到小做排列,再逐个相加。
而如果一定要做出这样的大数与小数的求和,一个直观想法就是:大数部分和小数部分的高位相加,将剩余的小数部分作为单独的“补全”部分相加。这种直观想法的官方名称叫做__Kahan求和法__。
假设当前的浮点数变量可以保存6位
的数值。那么,数值12345
与1.234
相加的理论值应该是12346.234
。但由于当前只能保存6位数值,这个正确的理论值会被截断为12346.2,这就出现了0.034的误差。当有很多这样的大数与小数相加时,截断误差就会逐步累积,导致最后的计算结果出现大的偏差。
Kahan算法:
先上伪代码(看不懂不要着急,后面要逐一解释):
def KahanSum(input):
var sum = 0.0
var c = 0.0
for i = 1 to input.length do
var y = input[i] + c # 在新的输入中,加上之前累积的补全部分.
var t = sum + y # 此时,将这个新的输入加入到求和变量sum里时,会造成低位部分被截断
c = y - (t - sum) # 恢复上一步求和时被截断的低位部分,累积进入补全变量
sum = t
next i
return sum
在上述伪代码中,变量c
表示的即是小数的补全部分compensation,更严格地说,应该是__负的__补全部分。随着这个补全部分的不断积累,当这些截断误差积累到一定量级,它们在求和的时候也就不会被截断了,从而能够相对好地控制整个求和过程的精度。
以下,先用一个具体的理论例子来说明。比如,用
$$10000.0 + \pi + \mathcal{e}$$
来说明。
我们依旧假设浮点型变量只能保存6位数值。此时,具体写出求和算式应该是:
$$10000.0 + 3.14159 + 2.71828$$
它们的理论结果应该是10005.85987
,约等于10005.9
。
但由于截断误差,第一次求和
$$10000.0 + 3.14159$$
只能得到结果10003.1
;这个结果再与2.71828
相加,得到10005.81828
,被截断为10005.8
。此时结果就相差了0.1
。
运用Kahan求和法,我们的运行过程是(记住,我们的浮点型变量__保存6位数值__),
第一次求和:
y = 3.14159 + 0.00000
t = 10000.0 + 3.14159
= 10003.14159
= 10003.1 # 低位部分被截断,丢失
c = 3.14159 - (10003.1 - 10000.0)
= 3.14159 - 3.10000
= (.0415900) # 恢复丢失的截断部分
sum = 1003.1
第二次求和:
y = 2.71828 + (.0415900)
= 2.75985 # 将上一步的补全部分增加在新的输入上
t = 10003.1 + 2.75987
= 10005.85987
= 10005.9 # 当低位部分被累积得足够大后,它就不会被大数的求和所截断消失
c = 2.75987 - (10005.9 - 10003.1)
= 2.75987 - 2.80000
= -.040130
sum = 10005.9
实例
以上是理论分析。下面用一个可以运行的Python代码做示范,方便感兴趣的朋友做研究。
这个例子曾经出现于Google的首席科学家_Vincent Vanhoucke_在Udacity上开设的__Deep Learning__课程。
这个求和算式是:在$10^9$的基础上,加上$10^{-6}$,总共重复$10^6$次这个加法,再减去$10^9$,即
$$10^9 + 10^6*10^{-6} - 10^9$$
理论值显然应该为1
。
summ = 10**9
for indx in range(10**6):
summ += 10**(-6)
summ -= 10**9
print(summ)
# output: 0.95367431640625
运行后,可以貌似惊讶地看到结果竟然不是1
,而是0.95367431640625
!
这可以说明,在$10^6$次求和后,截断误差的累积量已经非常可观了。
运用Kahan算法做改进
如果我们用Kahan求和法来做改进,可以得到:
summ = 10**9
c = 0.0
for indx in range(10**6):
y = 10**(-6) + c
t = summ + y
c = y - (t - summ)
summ = t
summ -= 10**9
print(summ)
# output: 1.0
运行后,我们可以欣喜地看到正确结果:1.0。
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。