张春成
2022/10/19
Correcting Cross Entropy
Cross entropy is a good thing, as long as it does not blow up.
The Instability of Cross Entropy
Cross entropy is easy to compute, and its minimum value is the information entropy. If we want a signal to provide as little information as possible, we can achieve this by compressing it.
In machine learning, if we already know, or can observe, one signal, then making another signal provide as little information as possible means that the other signal becomes easier to predict. This is a very useful property.
However, in some situations (which are not at all special, but in fact widespread), this upper bound is so loose that it cannot give any reasonable direction for optimization.
If your calculus is not completely rusty, you can guess that this is most likely to happen at an infinite discontinuity of the function. Recalling the definition of cross entropy, it is easy to see that the function hits an infinite discontinuity whenever the two probabilities are zero at different points.
This situation never arises for the Shannon entropy, because the term p log p tends to zero as p tends to zero, so bins with zero probability simply contribute nothing.
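Written out explicitly (with p the distribution of the signal we already know and q the distribution it is scored against, the sums running over the discrete bins), the definitions and the limits behind this argument are:

$$
H(p) = -\sum_x p(x)\log p(x),
\qquad
H(p, q) = -\sum_x p(x)\log q(x),
$$

$$
\lim_{t \to 0^{+}} t \log t = 0,
\qquad
p(x)\log q(x) \to -\infty \quad \text{as } q(x) \to 0^{+} \text{ while } p(x) > 0.
$$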
In the cross-entropy function, however, this is a serious problem, because the function is not differentiable at these points (it diverges there). How serious? Serious enough to overturn the computed results. As an example, we construct a set of random signals with the method used before; the signals with the largest and smallest entropy are shown in the figure below, and the smallest entropy is 3.33.
[Figure: the signals with the largest and smallest entropy]
Yet when we compute the cross entropy between every pair of signals, we get a rather bizarre result: the smallest cross entropy does not belong to the signal with the smallest entropy, but to a different one. This happens because the infinite discontinuities are "discarded" during the computation, out of necessity, since the numerical method simply cannot handle infinities.
[Figure: pairwise cross entropy before correction]
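To make the failure concrete, here is a minimal sketch with two made-up three-bin distributions; they are not the signals from the figures, only an illustration of the mechanism:

    import numpy as np

    # Two made-up discrete distributions.
    # q assigns zero probability to a bin where p does not.
    p = np.array([0.5, 0.4, 0.1])
    q = np.array([0.6, 0.4, 0.0])

    # The naive cross entropy diverges: p * log(q) hits log(0) = -inf.
    with np.errstate(divide='ignore'):
        naive = -np.sum(p * np.log(q))
    print(naive)  # inf

    # "Discarding" the offending bins keeps the number finite,
    # but it silently drops part of p's mass (here 10% of it),
    # so the values are no longer comparable across signal pairs --
    # this is what scrambles the ranking.
    mask = q != 0
    dropped = -np.sum(p[mask] * np.log(q[mask]))
    print(dropped)  # a finite number computed from only 90% of p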
A Simple Correction
This is obviously wrong, so a correction is needed. The idea of the correction borrows from Bayesian learning: the signal is assumed to have an intrinsic distribution of its own, and "observing" the signal's behavior only revises that intrinsic distribution; it never changes its essence.
Simply put, we add a constant to every probability value and then re-normalize the corrected probabilities.
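As a minimal sketch of that recipe, continuing the made-up distributions above (the helper name correct and the constant 0.01 are choices of this sketch; 0.01 happens to match the constant used in the class below):

    import numpy as np

    def correct(prob, eps=0.01):
        '''Add a small constant to every bin and re-normalize,
        so that no bin has exactly zero probability.'''
        prob = prob + eps
        return prob / np.sum(prob)

    p = correct(np.array([0.5, 0.4, 0.1]))
    q = correct(np.array([0.6, 0.4, 0.0]))

    # The corrected cross entropy is finite and uses all of p's mass.
    print(-np.sum(p * np.log(q)))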
The corrected cross entropy is shown below. Although this estimate increases the overall uncertainty, it does not change the ordering of the entropies across signals; at the very least it now correctly finds the signal with the smallest cross entropy, and the infinite-discontinuity problem is avoided.
[Figure: pairwise cross entropy after correction]
The correction also makes the task of using cross entropy to find the signal that looks most "similar" to a target signal behave properly, as the before/after figures below show.
[Figure: most similar signal found by cross entropy, before correction]
[Figure: most similar signal found by cross entropy, after correction]
The Correction Code
The correction code is given below. It is more or less a usable tool for automatic entropy computation; it is just still a bit slow.
import numpy as np
import scipy.stats
from tqdm import tqdm


class DataWithProb(object):
    '''
    Data with probability.
    Every entropy you need for a dataset.
    '''

    def __init__(self):
        pass

    def load(self, data, computes=['prob', 'joint_prob']):
        '''
        Load new data. The data is a 2-d array whose
        1st dimension indexes the signals.
        The function computes the requested quantities automatically:
        - prob: the empirical probability of every signal, a 2-d array
        - joint_prob: the empirical joint probability of every two signals,
          a 4-d array
        Since the data is binned, discrete probabilities are computed.

        :param data: The data to be computed
        '''
        self.data = data
        self.shape = data.shape
        print('Data shape is {}'.format(self.shape))
        bins = self.auto_bins()
        if 'prob' in computes:
            self.compute_prob(bins)
        if 'joint_prob' in computes:
            self.compute_joint_prob(bins)
        return

    def auto_bins(self, num=100):
        '''
        Compute bins with the default setting.

        :param num: How many bin edges to use
        :return bins: The linearly separated bins
        '''
        data = self.data
        bins = np.linspace(np.min(data), np.max(data), num)
        return bins

    def compute_prob(self, bins):
        '''
        Compute the probability of every signal.

        :param bins: The bin edges
        :return prob: The probability of every signal, with shape n x (m-1),
                      where m is the count of bin edges and n is the count of signals
        '''
        data = self.data
        n = data.shape[0]
        m = bins.shape[0]
        prob = np.zeros((n, m - 1))
        for j, d in enumerate(tqdm(data, 'Prob.')):
            a, _ = np.histogram(d, bins=bins)
            a = a.astype(np.float32)
            a /= np.sum(a)
            prob[j] = a
        self.prob = prob
        self.prob_bins = bins
        return prob

    def compute_joint_prob(self, bins):
        '''
        Compute the joint probability of every two signals.

        :param bins: The bin edges
        :return joint_prob: The joint probability matrix of every two signals,
                            with shape n x n x (m-1) x (m-1).
                            The first two axes refer to the signal pair;
                            the last two refer to the bins grid.
        '''
        data = self.data
        n = data.shape[0]
        m = bins.shape[0]
        joint_prob = np.zeros((n, n, m - 1, m - 1))
        for j in tqdm(range(n), 'Joint prob.'):
            for k in range(n):
                a, _, _ = np.histogram2d(data[j], data[k], bins=bins)
                a = a.astype(np.float32)
                a /= np.sum(a)
                joint_prob[j][k] = a
        self.joint_prob = joint_prob
        self.joint_prob_bins = bins
        return joint_prob

    def shannon_entropy(self):
        '''
        Compute the Shannon entropy of every signal.

        :return entropy: The Shannon entropy of every signal, a 1-d array
        '''
        entropy = np.array([scipy.stats.entropy(p)
                            for p in tqdm(self.prob, 'Shannon Entropy')])
        return entropy

    def joint_entropy(self):
        '''
        Compute the joint entropy of every two signals.

        :return joint_entropy: The joint entropy of every two signals, a 2-d array
        '''
        joint_prob = self.joint_prob
        n = joint_prob.shape[0]
        joint_entropy = np.zeros((n, n))
        for j in tqdm(range(n), 'Joint Entropy'):
            for k in range(n):
                p = joint_prob[j][k].flatten()
                joint_entropy[j][k] = scipy.stats.entropy(p)
        return joint_entropy

    def mutual_information(self):
        '''
        Compute the mutual information of every two signals.

        :return mutual_information: The mutual information of every two signals,
                                    a 2-d array
        '''
        prob = self.prob
        joint_prob = self.joint_prob
        n = joint_prob.shape[0]
        mutual_information = np.zeros((n, n))
        for j in tqdm(range(n), 'Mutual Information'):
            for k in range(n):
                # Product of the marginals, p(x) * p(y), as an outer product
                p1 = prob[j][:, np.newaxis]
                p2 = prob[k][np.newaxis, :]
                pxy = np.matmul(p1, p2)
                pp = joint_prob[j, k]
                # Keep only the cells with non-zero joint probability
                m = pp != 0
                pxy = pxy[m]
                pp = pp[m]
                s = pp * np.log(pxy / pp)
                mutual_information[j][k] = -np.sum(s)
        return mutual_information

    def cross_entropy(self):
        '''
        Compute the cross entropy of every two signals.

        :return cross_entropy: The cross entropy of every two signals, a 2-d array
        '''
        prob = self.prob.copy()
        n = prob.shape[0]
        cross_entropy = np.zeros((n, n))
        for j in tqdm(range(n), 'Cross Entropy'):
            for k in range(n):
                p1 = prob[j]
                p2 = prob[k]
                # Discard the bins where p2 is zero: this is exactly where
                # the infinite discontinuities live, and the source of the problem
                m = p2 != 0
                p1 = p1[m]
                p2 = p2[m]
                s = p1 * np.log(p2)
                cross_entropy[j][k] = -np.sum(s)
        return cross_entropy

    def cross_entropy_correct(self):
        '''
        Compute the corrected cross entropy of every two signals.

        :return cross_entropy: The corrected cross entropy of every two signals,
                               a 2-d array
        '''
        prob = self.prob.copy()
        # Correct the probabilities: add a small constant to every bin,
        # then re-normalize, so that no bin has exactly zero probability
        for p in prob:
            p += 0.01
            p /= np.sum(p)
        n = prob.shape[0]
        cross_entropy = np.zeros((n, n))
        for j in tqdm(range(n), 'Cross Entropy Correct'):
            for k in range(n):
                p1 = prob[j]
                p2 = prob[k]
                # No masking is needed here: after the correction p2 has no zeros
                s = p1 * np.log(p2)
                cross_entropy[j][k] = -np.sum(s)
        return cross_entropy
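For reference, one possible way to drive the class end to end; the random data, the variable names, and the shapes here are purely illustrative:

    import numpy as np

    # 20 random signals with 1000 samples each -- illustrative data only.
    data = np.random.randn(20, 1000)

    dwp = DataWithProb()
    dwp.load(data)

    h = dwp.shannon_entropy()             # entropy of each signal
    ce = dwp.cross_entropy()              # naive cross entropy, zeros discarded
    ce_fix = dwp.cross_entropy_correct()  # corrected cross entropy

    # With the correction, the signal most "similar" to signal 0 is simply
    # the one with the smallest corrected cross entropy against it.
    print(np.argsort(ce_fix[0])[:3])      # index 0 itself should come first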