[Kaldi] Reading the gmm-init-mono Source Code

Monophone GMMs are typically used in the cold-start (flat-start) stage of training an ASR system.

Overall structure

Before diving into the details, let's first look at the overall skeleton of the code:

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/am-diag-gmm.h"
#include "hmm/hmm-topology.h"
#include "hmm/transition-model.h"
#include "tree/context-dep.h"

namespace kaldi {
void ReadSharedPhonesList(std::string rxfilename, std::vector<std::vector<int32> > *list_out) {...}
}

int main(int argc, char *argv[]) {
...
}

As you can see, the structure of this source file is very simple: a few header includes, one function ReadSharedPhonesList defined in the kaldi namespace, and then the main function.

What ReadSharedPhonesList does

Let's look at ReadSharedPhonesList in detail:

namespace kaldi {
// This function reads a file like:
// 1 2 3
// 4 5
// 6 7 8
// where each line is a list of integer id's of phones (that should have their pdfs shared).
void ReadSharedPhonesList(std::string rxfilename, std::vector<std::vector<int32> > *list_out) {
  list_out->clear();
  Input input(rxfilename);
  std::istream &is = input.Stream();
  std::string line;
  while (std::getline(is, line)) {
    list_out->push_back(std::vector<int32>());
    if (!SplitStringToIntegers(line, " \t\r", true, &(list_out->back())))
      KALDI_ERR << "Bad line in shared phones list: " << line << " (reading "
                << PrintableRxfilename(rxfilename) << ")";
    std::sort(list_out->rbegin()->begin(), list_out->rbegin()->end());
    if (!IsSortedAndUniq(*(list_out->rbegin())))
      KALDI_ERR << "Bad line in shared phones list (repeated phone): " << line
                << " (reading " << PrintableRxfilename(rxfilename) << ")";
  }
}

} // end namespace kaldi

As you can see, this code simply reads the contents of rxfilename into a two-dimensional array; the phones listed on each row will share a pdf (probability density function).
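The same parsing logic can be reproduced with nothing but the standard library. Below is a minimal sketch (the name ParseSharedPhones is hypothetical; Kaldi's version uses its Input/KALDI_ERR machinery rather than exceptions):

```cpp
#include <algorithm>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Parse a shared-phones list from a stream: each line holds the integer ids
// of phones whose pdfs should be shared. Mirrors ReadSharedPhonesList using
// only the standard library.
std::vector<std::vector<int32_t>> ParseSharedPhones(std::istream &is) {
  std::vector<std::vector<int32_t>> result;
  std::string line;
  while (std::getline(is, line)) {
    std::istringstream line_stream(line);
    std::vector<int32_t> phones;
    int32_t id;
    while (line_stream >> id) phones.push_back(id);
    if (!line_stream.eof())  // a non-integer token was encountered
      throw std::runtime_error("Bad line in shared phones list: " + line);
    std::sort(phones.begin(), phones.end());
    // Reject duplicates, as IsSortedAndUniq does in the original.
    if (std::adjacent_find(phones.begin(), phones.end()) != phones.end())
      throw std::runtime_error("Repeated phone in line: " + line);
    result.push_back(std::move(phones));
  }
  return result;
}
```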

The main function

Next we move on to the main function. To give a rough picture of its overall structure, I've replaced the more involved code with "…"; those parts are covered below.

int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using kaldi::int32;

    const char *usage =
        "Initialize monophone GMM.\n"
        "Usage:  gmm-init-mono <topology-in> <dim> <model-out> <tree-out> \n"
        "e.g.: \n"
        " gmm-init-mono topo 39 mono.mdl mono.tree\n";

    bool binary = true;
    std::string train_feats;
    std::string shared_phones_rxfilename;
    BaseFloat perturb_factor = 0.0;
    ParseOptions po(usage);
    po.Register("binary", &binary, "Write output in binary mode");
    po.Register("train-feats", &train_feats,
                "rspecifier for training features [used to set mean and variance]");
    po.Register("shared-phones", &shared_phones_rxfilename,
                "rxfilename containing, on each line, a list of phones whose pdfs should be shared.");
    po.Register("perturb-factor", &perturb_factor,
                "Perturb the means using this fraction of standard deviation.");
    po.Read(argc, argv);

    if (po.NumArgs() != 4) {
      po.PrintUsage();
      exit(1);
    }

    ...

  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}

The part to notice here is ParseOptions. This class is somewhat like Python's argparse: it stores the usage message and parses command-line arguments. Its use follows right after:
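To see what such a class has to do, here is a toy sketch in the same spirit: it binds --name=value options to variables via Register() and collects everything else as positional arguments. MiniParseOptions is a made-up name; this is not Kaldi's actual implementation.

```cpp
#include <map>
#include <string>
#include <vector>

// A toy option parser in the spirit of Kaldi's ParseOptions.
class MiniParseOptions {
 public:
  void Register(const std::string &name, std::string *value) {
    string_opts_[name] = value;
  }
  void Register(const std::string &name, bool *value) {
    bool_opts_[name] = value;
  }
  void Read(int argc, const char *argv[]) {
    for (int i = 1; i < argc; i++) {
      std::string arg = argv[i];
      if (arg.rfind("--", 0) == 0) {  // an option like --name=value
        size_t eq = arg.find('=');
        std::string name = arg.substr(2, eq - 2);
        std::string value =
            (eq == std::string::npos) ? "true" : arg.substr(eq + 1);
        if (string_opts_.count(name)) *string_opts_[name] = value;
        else if (bool_opts_.count(name)) *bool_opts_[name] = (value != "false");
      } else {
        positional_.push_back(arg);  // a positional argument
      }
    }
  }
  int NumArgs() const { return static_cast<int>(positional_.size()); }
  std::string GetArg(int i) const { return positional_[i - 1]; }  // 1-based, as in Kaldi
 private:
  std::map<std::string, std::string*> string_opts_;
  std::map<std::string, bool*> bool_opts_;
  std::vector<std::string> positional_;
};
```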

std::string topo_filename = po.GetArg(1);
int dim = atoi(po.GetArg(2).c_str());
KALDI_ASSERT(dim > 0 && dim < 10000);
std::string model_filename = po.GetArg(3);
std::string tree_filename = po.GetArg(4);

The code above extracts topo_filename, dim, model_filename, and tree_filename. Their roles are as follows:

  • topo_filename: stores the phone-related topology information, discussed in detail below;
  • dim: the dimensionality of the observation vectors, e.g. 39 for MFCC features; it is constrained here to be below 10000;
  • model_filename: the file that stores the model (transition model and topo);
  • tree_filename: the file that stores the tree (ctx_dep).

Recall that a GMM is parameterized by means μ and covariances Σ. In general Σ is a matrix, but under the assumption that the feature dimensions are mutually independent, Σ is diagonal and can be represented by a vector. With that in mind the following code is easy to read; note that both μ and the inverse variances are initialized to all-ones vectors.
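As an aside on why the *inverse* variance is stored: with a diagonal Σ, the Gaussian log-density decomposes into a per-dimension sum, and keeping 1/σ² turns the division in each term into a multiplication. A small sketch of that evaluation (a hypothetical helper, not Kaldi API):

```cpp
#include <cmath>
#include <vector>

// Log-density of a diagonal-covariance Gaussian. Storing inverse variances
// makes the quadratic term diff * diff * inv_var a pure multiplication.
double DiagGaussianLogPdf(const std::vector<double> &x,
                          const std::vector<double> &mean,
                          const std::vector<double> &inv_var) {
  const double kLog2Pi = std::log(2.0 * M_PI);
  double log_prob = 0.0;
  for (size_t d = 0; d < x.size(); d++) {
    double diff = x[d] - mean[d];
    // log N(x_d; mu_d, 1/inv_var_d)
    //   = 0.5 * (log inv_var_d - log 2*pi - diff^2 * inv_var_d)
    log_prob += 0.5 * (std::log(inv_var[d]) - kLog2Pi - diff * diff * inv_var[d]);
  }
  return log_prob;
}
```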

Vector<BaseFloat> glob_inv_var(dim);
glob_inv_var.Set(1.0);
Vector<BaseFloat> glob_mean(dim);
glob_mean.Set(1.0);

The following code runs only when train_feats is specified. It initializes μ and Σ from the training data.

if (train_feats != "") {
  double count = 0.0;
  Vector<double> var_stats(dim);
  Vector<double> mean_stats(dim);
  SequentialDoubleMatrixReader feat_reader(train_feats);
  for (; !feat_reader.Done(); feat_reader.Next()) {
    const Matrix<double> &mat = feat_reader.Value();
    for (int32 i = 0; i < mat.NumRows(); i++) {
      count += 1.0;
      var_stats.AddVec2(1.0, mat.Row(i));
      mean_stats.AddVec(1.0, mat.Row(i));
    }
  }
  if (count == 0) { KALDI_ERR << "no features were seen."; }
  var_stats.Scale(1.0/count);
  mean_stats.Scale(1.0/count);
  var_stats.AddVec2(-1.0, mean_stats);
  if (var_stats.Min() <= 0.0)
    KALDI_ERR << "bad variance";
  var_stats.InvertElements();
  glob_inv_var.CopyFromVec(var_stats);
  glob_mean.CopyFromVec(mean_stats);
}

In the egs/timit/s5 recipe, gmm-init-mono is invoked as follows:

gmm-init-mono $shared_phones_opt "--train-feats=$feats subset-feats --n=10 ark:- ark:-|" $lang/topo $feat_dim \
  $dir/0.mdl $dir/tree || exit 1;

The data referred to by train_feats is:

feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |"

From the name apply-cmvn we can infer that cepstral mean and variance normalization is applied to the features. Note that the command string ends with a pipe symbol, so train_feats is delivered through a pipe; the advantage is that the data can conveniently undergo further processing. For example, subset-feats --n=10 keeps just 10 utterances as the data used for initialization.

The code below reads the configuration in topo_filename. All the phones end up in phones; note that at this point each phone has already been mapped to an integer index, which is why they fit in a vector<int32>.

HmmTopology topo;
bool binary_in;
Input ki(topo_filename, &binary_in);
topo.Read(ki.Stream(), binary_in);

const std::vector<int32> &phones = topo.GetPhones();

In the timit recipe, data/lang/topo looks like this (reformatted to be easier to read):

<Topology>
<TopologyEntry>
<ForPhones>
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
</ForPhones>
<State> 0 <PdfClass> 0 <Transition> 0 0.75 <Transition> 1 0.25 </State>
<State> 1 <PdfClass> 1 <Transition> 1 0.75 <Transition> 2 0.25 </State>
<State> 2 <PdfClass> 2 <Transition> 2 0.75 <Transition> 3 0.25 </State>
<State> 3 </State>
</TopologyEntry>
<TopologyEntry>
<ForPhones>
1
</ForPhones>
<State> 0 <PdfClass> 0 <Transition> 0 0.5 <Transition> 1 0.5 </State>
<State> 1 <PdfClass> 1 <Transition> 1 0.5 <Transition> 2 0.5 </State>
<State> 2 <PdfClass> 2 <Transition> 2 0.75 <Transition> 3 0.25 </State>
<State> 3 </State>
</TopologyEntry>
</Topology>

Typically each phone has 3 states, and each state has 2 outgoing transitions; the silence phone sil may look different. The topo file divides the phones into classes by topology: phones within one class share the same topology, and each class is defined by one <TopologyEntry></TopologyEntry> block.
For example, in the topology above, phones 2 through 48 share the same structure, as shown in the figure below. Each of them has 3 emitting states, numbered 0, 1, and 2, which presumably correspond to the onset, steady, and offset portions of the phone. State 3 is a non-emitting final state that stands for the start of the next phone, making it easy to chain phones together.
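As a sanity check on the transition probabilities: with a self-loop probability of $p = 0.75$, the number of frames spent in a state is geometrically distributed, so the expected dwell time per state is

```latex
\mathbb{E}[\text{frames in state}]
  = \sum_{k=1}^{\infty} k\,(1-p)\,p^{\,k-1}
  = \frac{1}{1-p}
  = \frac{1}{1-0.75}
  = 4 .
```

A 3-state phone therefore lasts about 12 frames in expectation, i.e. roughly 120 ms at the usual 10 ms frame shift, which is a plausible average phone duration.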

Next, the number of pdf classes is obtained for each phone:

std::vector<int32> phone2num_pdf_classes (1 + phones.back());
for (size_t i = 0; i < phones.size(); i++)
  phone2num_pdf_classes[phones[i]] = topo.NumPdfClasses(phones[i]);

Here the ContextDependency class makes its appearance; according to the Kaldi documentation, it derives from ContextDependencyInterface.

// Now the tree [not really a tree at this point]:
ContextDependency *ctx_dep = NULL;
if (shared_phones_rxfilename == "") {  // No sharing of phones: standard approach.
  ctx_dep = MonophoneContextDependency(phones, phone2num_pdf_classes);
} else {
  std::vector<std::vector<int32> > shared_phones;
  ReadSharedPhonesList(shared_phones_rxfilename, &shared_phones);
  // ReadSharedPhonesList crashes on error.
  ctx_dep = MonophoneContextDependencyShared(shared_phones, phone2num_pdf_classes);
}

int32 num_pdfs = ctx_dep->NumPdfs();
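How many pdfs does this produce? Without sharing, every (phone, pdf-class) pair gets its own pdf; with sharing, all the phones on one line of the shared-phones list jointly contribute one pdf per pdf-class. A back-of-the-envelope sketch of that count (illustrative only; Kaldi computes this inside the ContextDependency object, and the function name here is made up):

```cpp
#include <cstdint>
#include <vector>

// Count the pdfs a monophone tree will contain: one pdf per pdf-class for
// each set of phones, where unshared phones form singleton sets.
int32_t CountMonophonePdfs(
    const std::vector<std::vector<int32_t>> &phone_sets,   // one set per line
    const std::vector<int32_t> &phone2num_pdf_classes) {
  int32_t num_pdfs = 0;
  for (const auto &set : phone_sets)
    num_pdfs += phone2num_pdf_classes[set.front()];  // shared within the set
  return num_pdfs;
}
```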

The next block is fairly easy to follow: it creates an am_gmm and a gmm. Clearly am_gmm is the GMM of the acoustic model, and gmm is a component of am_gmm. The gmm's parameters are set from the previously initialized glob_inv_var and glob_mean.

AmDiagGmm am_gmm;
DiagGmm gmm;
gmm.Resize(1, dim);
{ // Initialize the gmm.
  Matrix<BaseFloat> inv_var(1, dim);
  inv_var.Row(0).CopyFromVec(glob_inv_var);
  Matrix<BaseFloat> mu(1, dim);
  mu.Row(0).CopyFromVec(glob_mean);
  Vector<BaseFloat> weights(1);
  weights.Set(1.0);
  gmm.SetInvVarsAndMeans(inv_var, mu);
  gmm.SetWeights(weights);
  gmm.ComputeGconsts();
}

for (int i = 0; i < num_pdfs; i++)
  am_gmm.AddPdf(gmm);

If the perturbation factor perturb-factor is specified, the following code runs; the timit recipe does not appear to set this option.

if (perturb_factor != 0.0) {
  for (int i = 0; i < num_pdfs; i++)
    am_gmm.GetPdf(i).Perturb(perturb_factor);
}

Finally, all the models are written to disk: model_filename stores trans_model and am_gmm, while tree_filename separately stores ctx_dep.

// Now the transition model:
TransitionModel trans_model(*ctx_dep, topo);

{
  Output ko(model_filename, binary);
  trans_model.Write(ko.Stream(), binary);
  am_gmm.Write(ko.Stream(), binary);
}

// Now write the tree.
ctx_dep->Write(Output(tree_filename, binary).Stream(), binary);
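One detail worth noting: trans_model and am_gmm are written back-to-back into the same stream, so a reader must consume them in exactly the same order. A toy illustration of this pattern with standard streams (ToyModel and the helper names are made up, not Kaldi's Write/Read API):

```cpp
#include <sstream>
#include <string>

// Two objects serialized back-to-back into one stream must be deserialized
// in the same order -- the pattern gmm-init-mono uses for the transition
// model followed by the acoustic model. Text-mode toy example.
struct ToyModel {
  int value = 0;
  void Write(std::ostream &os) const { os << value << ' '; }
  void Read(std::istream &is) { is >> value; }
};

std::string WriteBoth(const ToyModel &trans, const ToyModel &am) {
  std::ostringstream os;
  trans.Write(os);  // first object
  am.Write(os);     // second object, same stream
  return os.str();
}

void ReadBoth(const std::string &data, ToyModel *trans, ToyModel *am) {
  std::istringstream is(data);
  trans->Read(is);  // must match the write order
  am->Read(is);
}
```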