Skip to content

BUPT web search home work。基于SML算法的GMM模型。

License

Notifications You must be signed in to change notification settings

xtymbbt/WebSearchHW

Repository files navigation

基于SML算法的GMM高斯混合模型

介绍

本程序训练了4个语义类,建筑、树木、天空、公路。
每个语义类中有10张训练图像。
图像的高斯混合模型为4个分量,类的高斯混合模型为16个分量。
由于计算过程非常耗时,因此K-means和高斯混合模型的EM算法没有迭代太多次。

主要文件介绍:

test.py

该文件起初为测试function.py中各个函数所用,因此文件中会有许多注释掉的代码。 现在可作为对单个图片进行标注的文件。通过加载预训练好的模型参数,实现对单个图片 进行测试。

main.py

程序主函数。运行该函数即可开始运行程序。

functions.py

存放各种函数。

  • def sml():

该函数为各个函数的综合,按照SML算法的顺序依次执行。

  • def load_picture(path):

该函数用于从指定路径path中加载dataset中的图片, 然后将同一个语义文件夹中的所有图片综合到一个矩阵img中,并返回。

  • def split_picture(img):

该函数用于执行Web搜索课本中SML算法的第②-a步,即,通过一个像素数为8×8, 每次滑动2个像素的窗口,将img矩阵中的每一张图片分解为相互重叠的N个区域。 一共(图像宽度-6)/2 × (图像高度)-6)/2 个区域,每个区域像素数为8×8。 最终返回一个大小为(图片数m,区域数n,8,8,通道数3)的矩阵。

  • def split_picture_single_image(img):

该函数作用与上一个函数作用相似,不同之处在于,此函数是用于给单张图片贴标签用的。 因此,该函数中只计算单张图片的切分。

  • def dct(spl_img):

该函数用于执行Web搜索课本中SML算法的第②-b步。即,分别计算n个区域的DCT变换, 然后将图片进行压缩。本程序中只取了每个区域DCT变换后所得到的图像的左上角的区域, 该区域是DCT变换之后大部分能量集中的区域。 最后,将三个通道进行交错排列。
:本函数对原书中的顺序稍微做了一点小改动。 原书中是先进行交错排列,后进行数据压缩。 而本函数适应了numpy方便的矩阵运算,因此,先进行了数据压缩,后进行了交错排列。

  • def k_means(shuffle_fl):

该函数为单张图片的K-means算法。
该函数对每一张图片中的N个区域进行K-means聚类,得到GMM模型的初始化均值。

  • def gaussian(x, mu, sigma):

高斯函数的计算方程式。
输入为x,x的均值mu,x的协方差矩阵sigma。
输出为高斯分布的概率值。

  • def single_em(img, mu_all):

通过EM算法进行高斯混合分布的估计。
需要注意的是,需要在其中预先计算出N个区域的均值向量以及协方差矩阵, 对于协方差矩阵,不能随机初始化协方差矩阵,根据均值和N个区域的值, 很容易可以计算出协方差矩阵。如果随机初始化的话, 那么将会造成高斯分布计算出错的问题,从而影响后续进程。
输入:
img:语义w中的每一张图片,图片已经过DCT变换,已经过数据压缩,已经过交错排列。
mu_all:语义w中的每一张图片被切分成的N个区域,经过扁平化之后的向量,经过K-means算法之后初始化得到的均值。

  • def k_means_ex_em(img_mu):

用于扩展EM算法的K-means算法。即用于类模型的的K-means算法。
对通过EM算法之后得到的高斯混合分布模型的均值进行K-means聚类, 从而得到扩展EM算法的初始化均值。

  • def ex_em(img_mu, img_sigma, img_pi, class_mu):

扩展EM算法。即用于类模型的EM算法。
均值向量及协方差矩阵同single_em。需要预先计算出协方差矩阵的值。 输入:
img_mu:每张图片的混合高斯分布的均值。
img_sigma:每张图片的混合高斯分布的协方差矩阵。
img_pi:每张图片中4个混合高斯分布所占的权重。
class_mu:通过扩展EM算法得到的混合高斯分布的均值。

  • label(path, w1, w2, w3, w4):

对要进行标注的图片(在path路径下)进行标注。
w1,w2,w3,w4为四个语义类的模型信息,为Python的dictionary类型, 其中包含有每个类模型的均值、协方差矩阵、权重。
通过利用贝叶斯后验概率准则对该图片所属的语义类进行估计。
最后的返回值是该图片的所属的语义类。

收敛后的模型:

  • 收敛后的模型存放位置在classModel文件夹中,文件格式为.csv
    值得注意的是,对于class_sigma而言,由于是16个协方差矩阵,其维度是三维的, 但是pandas中的data frame只能存储二维数据,因此,需要将其由三维转为二维, 通过numpy中自带的reshape来将后两维重新排列成一维。
    因此,在读取该文件时,对于class_sigma,需要将.csv文件中所存储的第二维重新 变为两个维度,即,总共变为三个维度,后两个维度的具体大小为第二个维度的大小开方。

w1:用于存放语义为w1的模型数据,其中包括均值、协方差、权重
w2:用于存放语义为w2的模型数据,其中包括均值、协方差、权重
w3:用于存放语义为w3的模型数据,其中包括均值、协方差、权重
w4:用于存放语义为w4的模型数据,其中包括均值、协方差、权重

About

BUPT web search home work。基于SML算法的GMM模型。

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages