张春成

V2

2022/10/19阅读:30主题:默认主题

交叉熵的校正

交叉熵的校正

交叉熵是个好东西,只要它不崩溃。


交叉熵的不稳定性

交叉熵易于计算,并且它的极小值是信息熵,如果我们想让某个信号提供的信息量尽可能的少,就可以通过压缩它来实现。

在机器学习中,如果我们已知、或者说能够观测到某个信号,那么,让另一个信号提供的信息量尽可能少,就意味着另一个信号更容易预测。这是很有用的特性。

然而,在某些情况下(当然,这种情况并不特殊,而是广泛存在的情况),这个上限过于宽松,以至于不能给出任何合理的优化方向。

如果你的微积分还没有完全放下的话,就能够想到,这种情况很有可能出现在函数的无穷间断点处。回顾交叉熵的定义,容易发现两个概率在不同的点取零时,函数会遇到一个无穷间断点。

这种情况在 Shannon Entropy 中不会遇到,是因为

然而在交叉熵函数中,这种情况就是一个严重的问题。因为函数在这些地方不可微。这种问题有多严重呢?它们可能颠覆计算结果。举个例子,我们用之前的方法构造一组随机信号,它们最大和最小熵的信号如下图所示,可见,其中最小熵的信号的熵是 3.33。

newplot (25).png
newplot (25).png

而我们对信号两两之间做交叉熵,会得到一个十分诡异的结果,那就是交叉熵的最小值并非是熵最小的信号,而是另有其人。这是由于在计算中“舍弃”了无穷间断点,这样做是由于迫不得已,因为计算方法无法处理无穷这种东西。

newplot (26).png
newplot (26).png

简单校正

这显然是不对的,因此需要进行校正。校正过程的思路是借鉴贝叶斯学习的思路,认为信号本身有其本征分布,我们“观测”到信号的行为只是对本征分布进行修正,而永远不会改变它的本质

简单来说,就是将全部概率值都加上一个常数,再将修正后的概率进行归一化即可。

经过校正后得到的交叉熵如下,虽然这种估计增加了整体的不确定性,但没有改变信号之间熵的排序 关系,至少它能够正常找到最小的交叉熵信号,并且避免了无穷间断点问题。

newplot (27).png
newplot (27).png

在校正前、后,通过交叉熵寻找与目标信号最“像”的信号的工作也由于校正的存在而变得更加正常

校正前
校正前

校正前

校正后
校正后

校正后

校正代码

校正代码如下,它基本上已经算是自动熵计算的可用工具了,就是还有点慢。

class DataWithProb(object):
    '''
    Data with Probability

    Every entropy you need for a dataset.
    '''

    def __init__(self):
        pass

    def load(self, data, computes=['prob''joint_prob']):
        '''
        Load new data, the data is 2-d array,
        The 1st dimension is the signal.

        The function will compute things automatically
        - prob.: The empirical probability of every signal,
                 in 1-d array
        - joint prob.: The empirical joint probability of every two signals,
                       in 2-d array

        Since the data is inherently discrete, the discrete prob is computed.

        :param:data: The data will be computed
        '''

        self.data = data
        self.shape = data.shape
        print('Data shape is {}'.format(self.shape))
        bins = self.auto_bins()

        if 'prob' in computes:
            self.compute_prob(bins)

        if 'joint_prob' in computes:
            self.compute_joint_prob(bins)

        return

    def auto_bins(self, num=100):
        '''
        Compute bins by default setting.

        :param: num: A number of how many bins we will use

        :return: bins: Hhe linear seperated bins
        '''

        data = self.data
        bins = np.linspace(np.min(data), np.max(data), num)
        return bins

    def compute_prob(self, bins):
        '''
        Compute prob for every signal

        :param: bins: The bins
        :return: prob: The prob for every signal, the shape is n x (m-1),
                       where m is the count of bins, n is the count of signal
        '''

        data = self.data
        n = data.shape[0]
        m = bins.shape[0]

        prob = np.zeros((n, m-1))

        j = 0
        for d in tqdm(data, 'Prob.'):
            a, b = np.histogram(d, bins=bins)
            a = a.astype(np.float32)
            a /= np.sum(a)
            prob[j] = a
            j += 1
            pass

        self.prob = prob
        self.prob_bins = bins
        return prob

    def compute_joint_prob(self, bins):
        '''
        Compute joint prob for every two signal

        :param: bins: The bins
        :return: joint_prob: The joint_prob matrix of every two signal,
                             the shape is n x n x (m-1) x (m-1).
                             The first two n refer the signal pair;
                             the last two (m-1) refer the bins grid.
        '''

        data = self.data
        n = data.shape[0]
        m = bins.shape[0]

        joint_prob = np.zeros((n, n, m-1, m-1))

        for j in tqdm(range(n), 'Joint prob.'):
            for k in range(n):
                a, b, c = np.histogram2d(data[j], data[k], bins=bins)
                a = a.astype(np.float32)
                a /= np.sum(a)
                joint_prob[j][k] = a
                pass

        self.joint_prob = joint_prob
        self.joint_prob_bins = bins
        return joint_prob

    def shannon_entropy(self):
        '''
        Compute shannon entropy for every signal

        :return: entropy: The shannon entropy for every signal,
                          it is a 1-d array
        '''

        entropy = np.array([scipy.stats.entropy(p) for p in tqdm(self.prob, 'Shannon Envropy')])
        return entropy

    def joint_entropy(self):
        '''
        Compute joint entropy for every two signals

        :return: joint_entropy: The joint entropy for every two signal,
                                it is a 2-d array
        '''

        joint_prob = self.joint_prob
        n = joint_prob.shape[0]

        joint_entropy = np.zeros((n, n))

        for j in tqdm(range(n), 'Joint Entropy'):
            for k in range(n):
                p = joint_prob[j][k].flatten()
                e = scipy.stats.entropy(p)
                joint_entropy[j][k] = e
                pass

        return joint_entropy

    def mutual_information(self):
        '''
        Compute mutual information for every two signals

        :return: mutual_information: The mutual information for every two signals,
                                     it is a 2-d array
        '''

        prob = self.prob
        joint_prob = self.joint_prob
        n = joint_prob.shape[0]

        mutual_information = np.zeros((n, n))

        for j in tqdm(range(n), 'Mutual Information'):
            for k in range(n):
                p1 = prob[j][:, np.newaxis]
                p2 = prob[k][np.newaxis, :]
                pxy = np.matmul(p1, p2)
                pp = joint_prob[j, k]

                m = pp != 0
                pxy = pxy[m]
                pp = pp[m]

                s = pp * np.log(pxy / pp)
                mutual_information[j][k] = -np.sum(s)

        return mutual_information

    def cross_entropy(self):
        '''
        Compute cross entropy for every two signals

        :return: cross_entropy: The cross entropy for every two signals,
                                it is a 2-d array
        '''

        prob = self.prob.copy()

        n = prob.shape[0]

        cross_entropy = np.zeros((n, n))

        for j in tqdm(range(n), 'Cross Entropy'):
            for k in range(n):
                p1 = prob[j]
                p2 = prob[k]

                m = p2 != 0
                p1 = p1[m]
                p2 = p2[m]

                s = p1 * np.log(p2)
                cross_entropy[j][k] = -np.sum(s)

        return cross_entropy

    def cross_entropy_correct(self):
        '''
        Compute cross entropy for every two signals

        :return: cross_entropy: The cross entropy for every two signals,
                                it is a 2-d array
        '''

        prob = self.prob.copy()

        # Correct Prob
        for p in prob:
            p += 0.01
            p /= np.sum(p)

        n = prob.shape[0]

        cross_entropy = np.zeros((n, n))

        for j in tqdm(range(n), 'Cross Entropy Correct'):
            for k in range(n):
                p1 = prob[j]
                p2 = prob[k]

                m = p2 != 0
                p1 = p1[m]
                p2 = p2[m]

                s = p1 * np.log(p2)
                cross_entropy[j][k] = -np.sum(s)

        return cross_entropy

分类:

后端

标签:

后端

作者介绍

张春成
V2