viterbi过程
1.hmm类似。 状态转移,发射概率
2.逐次计算每个序列节点的所有状态下的概率值,最大概率值对应的index。
3.概率值的计算,上一个节点的概率值*转移概率+当前概率值。
4.最后取出最大的一个值对应的indexes
难点: 理解viterbi的核心点,在于每个时间步都保留每一个可视状态,每一个可视状态保留上一个时间步的最大隐状态转移,
每一个时间步t记录上一个最大概率转移过来的时间步t-1的信息,包括index/概率值累积。
迭代完时间步,根据最后一个最大累积概率值,逐个往前找即可。 根据index对应的状态逐个往前找。
应用: 状态转移求解最佳转移路径。 只要连续时间步,每个时间步有状态分布,前后时间步之间有状态转移,就可以使用viterbi进行最佳状态转移计算求解。
状态转移矩阵的作用在于 在每个状态转移概率计算时,和固有的状态转移矩阵进行加和,再计算。相当于额外的概率添加。
import numpy as np
def viterbi_decode(score, transition_params):
"""
保留所有可视状态下,对seqlen中的每一步的所有可视状态情况下的中间状态求解概率最大值,如此
:param score:
:param transition_params:
:return:
"""
# score [seqlen,taglen] transition_params [taglen,taglen]
trellis=np.zeros_like(score)
trellis[0]=score[0]
backpointers=np.zeros_like(score,dtype=np.int32)
for t in range(1,len(score)):
matrix_node=np.expand_dims(trellis[t-1],axis=1)+transition_params #axis=0 代表发射概率初始状态
trellis[t]=score[t]+np.max(matrix_node,axis=0)
backpointers[t]=np.argmax(matrix_node,axis=0)
viterbi=[np.argmax(trellis[-1],axis=0)]
for backpointer in reversed(backpointers[1:]):
viterbi.append(backpointer[viterbi[-1]])
viterbi_score = np.max(trellis[-1])
viterbi.reverse()
print(trellis)
return viterbi,viterbi_score
def calculate():
score = np.array([[1, 2, 3],
[2, 1, 3],
[1, 3, 2],
[3, 2,1]]) # (batch_size, time_step, num_tabs)
transition = np.array([ [2, 1, 3], [1, 3, 2], [3, 2, 1] ] )# (num_tabs, num_tabs)
lengths = [len(score[0])] # (batch_size, time_step) # numpy print("[numpy]")
# np_op = viterbi_decode( score=np.array(score[0]), transition_params=np.array(transition))
# print(np_op[0])
# print(np_op[1])
print("=============") # tensorflow
# score_t = tf.constant(score, dtype=tf.int64)
# transition_t = transition, dtype=tf.int64
tf_op = viterbi_decode( score, transition)
print('--------------------')
print(tf_op)
if __name__=='__main__':
calculate()
// java 版本
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
public class viterbi {
public static int[] viterbi_decode(double[][]score,double[][]trans ) {
//score(16,31) trans(31,31)
int path[] = new int[score.length];
double trellis[][] = new double[score.length][score[0].length];
int backpointers[][] = new int [score.length][score[0].length];
trellis[0] = score[0];
for(int t = 1; t<score.length;t++) {
// 一维数组,31个元素 [-1000,-1000,-1000,.......]
double h[] = trellis[t - 1];
//i shape(31 ,1) 31行,1列 [ [-1000][-10000][-1000] ]
//i = np.expand_dims(trellis[t - 1], 1)
//
// double expand_dims[][] = new double[trans.length][trans[0].length]; //??
// for(int j = 0;j<expand_dims[0].length;j++) {
// expand_dims[j] = h; //todo
// }
//zyy begin
double expand_h[][]=new double[trans.length][trans[0].length];
for(int i=0;i<trans.length;i++){
for(int j=0;j<trans.length;j++) {
expand_h[i][j]=h[i];
}
}
double expand_dims[][] = new double[trans.length][trans[0].length]; //??
for(int j = 0;j<expand_dims[0].length;j++) {
expand_dims[j] =expand_h[j] ; //todo
}
//zyy_end
double v[][] = new double[trans.length][trans[0].length];
for(int i = 0; i < v.length; i++ ) {
for(int j = 0; j< v[0].length ;j++) {
v[i][j] = expand_dims[i][j] + trans[i][j];
}
}
//取每列最大的值 得到score.length个每列最大值,一维数组
double max_v[] = new double[trans[0].length];
int max_v_linepoint[] = new int[trans[0].length];
for (int j = 0; j < v[0].length; j++) {
double max_column = v[0][j];
int line_point = 0;
for (int i = 0; i < v.length; i++) {
if(v[i][j] > max_column) {
max_column = v[i][j];
line_point = i;
}
}
max_v[j] = max_column;
max_v_linepoint[j] = line_point;
}
for(int i = 0 ;i < score[0].length; i++ ) {
trellis[t][i] = score[t][i] + max_v[i];
backpointers[t][i] = max_v_linepoint[i];
}
}
int viterbi[] = new int[score.length];
// List<Integer> viterbi = new ArrayList<>();
double max_trellis = trellis[score.length-1][0];
for(int j = 0; j< trellis[score.length-1].length ;j++) {
if(trellis[score.length-1][j] > max_trellis) {
max_trellis = trellis[score.length-1][j];
// viterbi.add(j);
viterbi[0] = j;
}
}
for(int i=1;i< 1+(backpointers.length)/2;i++){
int temp[] = backpointers[i];
backpointers[i] = backpointers[backpointers.length-i];
backpointers[backpointers.length-i]=temp;
}
for(int i = 1; i < backpointers.length; i++ ) {
// viterbi.add( backpointers[i][viterbi.get(viterbi.size() - 1)]);
viterbi[i] = backpointers[i][viterbi[i-1]];
}
for(int i = 0;i < (viterbi.length)/2; i++){ //把数组的值赋给一个临时变量
int temp = viterbi[i];
viterbi[i] = viterbi[viterbi.length-i-1];
viterbi[viterbi.length-i-1] = temp;
}
return viterbi;
}
public static void main(String[] args){
List<List<Integer>> score=new ArrayList<>();
ArrayList<Integer> row1=new ArrayList<>();
row1.add(1);
row1.add(2);
row1.add(3);
ArrayList<Integer> row2=new ArrayList<>();
row2.add(2);
row2.add(1);
row2.add(3);
ArrayList<Integer> row3=new ArrayList<>();
row3.add(1);
row3.add(3);
row3.add(2);
ArrayList<Integer> row4=new ArrayList<>();
row4.add(3);
row4.add(2);
row4.add(1);
score.add(row1);
score.add(row2);
score.add(row3);
score.add(row4);
List<List<Integer>> trans=new ArrayList<>();
ArrayList<Integer> row11=new ArrayList<>();
row11.add(2);
row11.add(1);
row11.add(3);
ArrayList<Integer> row12=new ArrayList<>();
row12.add(1);
row12.add(3);
row12.add(2);
ArrayList<Integer> row13=new ArrayList<>();
row13.add(3);
row13.add(2);
row13.add(1);
trans.add(row11);
trans.add(row12);
trans.add(row13);
// double[][] score_double=(double[][]) score.toArray();
// double[][] trans_double=(double[][]) trans.toArray();
System.out.println(score);
System.out.println(trans);
double[][] score_double=new double[score.size()][score.get(0).size()];
for(int i=0;i<score.size();i++){
// score_double[i]=score.get(i);
for(int j=0;j<score.get(0).size();j++){
score_double[i][j]=score.get(i).get(j);
}
}
double[][] trans_double=new double[trans.size()][trans.get(0).size()];
for(int i=0;i<trans.size();i++){
// score_double[i]=score.get(i);
for(int j=0;j<trans.get(0).size();j++){
trans_double[i][j]=trans.get(i).get(j);
}
}
int[] result=viterbi_decode(score_double,trans_double);
System.out.println("===========****===============");
System.out.println(result.toString());
}
}
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。