matlab net 参数,MatLab BayesNetToolbox参数学习

我的问题特定于MatLab中BayesNetToolbox的“learn_params()”函数。在用户手册中,“learn_params()”仅适用于完全遵守输入数据的情况。我用一个部分观察的数据集对其进行了尝试,在那里我将未观测到的值表示为NaN。

看起来像“learn_params()”可以处理NaN和数据集中不存在的节点状态组合。当我应用dirichlet先验来平滑0值时,我得到了所有节点的“明智的”MLE分布。我在这里复制了脚本。

有人能澄清我所做的事情是否有意义,或者我是否失踪

有些东西,即“learn_params()”不能部分使用的原因

观察数据。

我测试这个的MatLab脚本在这里:

% Incomplete dataset (where NaN's are unobserved)

Age = [1,2,2,NaN,3,3,2,1,NaN,2,1,1,3,NaN,2,2,1,NaN,3,1];

TNMStage = [2,4,2,3,NaN,1,NaN,3,1,4,3,NaN,2,4,3,4,1,NaN,2,4];

Treatment = [2,3,3,NaN,2,NaN,4,4,3,3,NaN,2,NaN,NaN,4,2,NaN,3,NaN,4];

Survival = [1,2,1,2,2,1,1,1,1,2,2,1,2,2,1,2,1,2,2,1];

matrixdata = [Age;TNMStage;Treatment;Survival];

node_sizes =[3,4,4,2];

% Enter the variablesmap

keys = {'Age', 'TNM','Treatment', 'Survival'};

v= 1:1:length(keys);

VariablesMap = containers.Map(keys,v);

% create the dag and the bnet

N = length(node_sizes); % Instead of entering it manually

dag2 = zeros(N,N);

dag2(VariablesMap('Treatment'),VariablesMap('Survival')) = 1;

bnet21 = mk_bnet(dag2, node_sizes);

draw_graph(bnet21.dag);

dirichletweight=1;

% define the CPD priors you want to use

bnet23.CPD{VariablesMap('Age')} = tabular_CPD(bnet23, VariablesMap('Age'), 'prior_type', 'dirichlet','dirichlet_type', 'unif', 'dirichlet_weight', dirichletweight);

bnet23.CPD{VariablesMap('TNM')} = tabular_CPD(bnet23, VariablesMap('TNM'), 'prior_type', 'dirichlet','dirichlet_type', 'unif', 'dirichlet_weight', dirichletweight);

bnet23.CPD{VariablesMap('Treatment')} = tabular_CPD(bnet23, VariablesMap('Treatment'), 'prior_type', 'dirichlet','dirichlet_type', 'unif','dirichlet_weight', dirichletweight);

bnet23.CPD{VariablesMap('Survival')} = tabular_CPD(bnet23, VariablesMap('Survival'), 'prior_type', 'dirichlet','dirichlet_type', 'unif','dirichlet_weight', dirichletweight);

% Find MLEs from incomplete data with Dirichlet prior CPDs

bnet24 = learn_params(bnet23, matrixdata);

% Look at the new CPT values after parameter estimation has been carried out

CPT24 = cell(1,N);

for i=1:N

s=struct(bnet24.CPD{i}); % violate object privacy

CPT24{i}=s.CPT;

end