
word2vec_java Source Code Analysis

白驰
2023-12-01

Step 1: read the corpus (already tokenized) and record each word's occurrence frequency in wordMap.

private void readVocab(File file) throws IOException {
    MapCount<String> mc = new MapCount<>();
    try (BufferedReader br = new BufferedReader(new InputStreamReader(
        new FileInputStream(file)))) {
      String temp = null;
      while ((temp = br.readLine()) != null) {
        // Each line is a pre-tokenized sentence; words are space-separated.
        String[] split = temp.split(" ");
        trainWordsCount += split.length;
        for (String string : split) {
          mc.add(string); // count one occurrence of this word
        }
      }
    }
    // Wrap every distinct word in a WordNeuron. Note that freq is the raw
    // count divided by the vocabulary size (mc.size()), not by trainWordsCount.
    for (Entry<String, Integer> element : mc.get().entrySet()) {
      wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
          (double) element.getValue() / mc.size(), layerSize));
    }
  }
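
MapCount is a small counting helper from the library; its definition is not part of this excerpt. A minimal sketch that is consistent with the three calls used above (add, get, size) could look like the following — the merge-based implementation is an assumption, not necessarily the library's own code:

import java.util.HashMap;

public class MapCount<T> {
  private final HashMap<T, Integer> hm = new HashMap<>();

  public void add(T t) {
    hm.merge(t, 1, Integer::sum); // increment this word's count, starting at 1
  }

  public int size() {
    return hm.size(); // number of distinct words, i.e. the vocabulary size
  }

  public HashMap<T, Integer> get() {
    return hm; // the raw word -> count map
  }
}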

Step 2: build a Huffman tree. The neurons are kept in a set sorted by word frequency, and the leaves are the per-word neurons from step 1. Once the tree is built, every neuron in wordMap has a parent neuron.

private void merger() {
    HiddenNeuron hn = new HiddenNeuron(layerSize);
    // Take the two lowest-frequency nodes still in the sorted set.
    Neuron min1 = set.pollFirst();
    Neuron min2 = set.pollFirst();
    hn.category = min2.category;
    hn.freq = min1.freq + min2.freq; // parent frequency = sum of children
    min1.parent = hn;
    min2.parent = hn;
    min1.code = 0; // the rarer child becomes the 0 (left) branch
    min2.code = 1; // the other child becomes the 1 (right) branch
    set.add(hn);   // the merged node goes back into the set
  }
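
Each call to merger() removes two nodes and adds one back, so the set shrinks by one node per call. The driver that seeds the set and repeats the merge is not shown here; a hypothetical sketch, assuming set is a TreeSet<Neuron> whose comparator orders by freq and never treats two distinct neurons as equal (a TreeSet would otherwise silently drop one of them), might be:

public void make(Collection<Neuron> neurons) {
    set.addAll(neurons); // all word neurons, i.e. the future leaves
    while (set.size() > 1) {
      merger(); // remove the two rarest nodes, add their merged parent
    }
    // The single element left in the set is the root of the Huffman tree.
  }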

Step 3: for each leaf, collect its ancestor nodes, ordered from the root downward, in the leaf's List<Neuron> neurons; these are the path neurons. From the path neurons we can derive the leaf's code array codeArr.

codeArr holds the code attribute of every node from the second tree level (the root's children) down to the current leaf itself. A code of 0 means the left child and 1 the right child, and the left child always has the lower frequency. The root carries no code of its own, which is why the loop below starts at i = 1.

public List<Neuron> makeNeurons() {
    if (neurons != null) {
      return neurons; // already computed; the path is cached
    }
    Neuron neuron = this;
    neurons = new LinkedList<>();
    // Climb from the leaf to the root, collecting every ancestor.
    while ((neuron = neuron.parent) != null) {
      neurons.add(neuron);
    }
    Collections.reverse(neurons); // now ordered root -> ... -> parent
    codeArr = new int[neurons.size()];

    // Skip the root (index 0): it has no branch code of its own.
    for (int i = 1; i < neurons.size(); i++) {
      codeArr[i - 1] = neurons.get(i).code;
    }
    // The last slot is the leaf's own code, the final branch taken.
    codeArr[codeArr.length - 1] = this.code;

    return neurons;
  }
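
A tiny standalone example makes the result concrete. The Node class below is a hypothetical stand-in for the library's Neuron, and the frequencies (a:4, b:2, c:1) are chosen so the merge order is unambiguous: c and b merge first into an inner node h1 (freq 3), then h1 and a merge under the root.

import java.util.*;

public class CodeArrDemo {
  static class Node { // hypothetical stand-in for Neuron
    int freq, code;
    Node parent;
    Node(int freq) { this.freq = freq; }
  }

  public static void main(String[] args) {
    Node a = new Node(4), b = new Node(2), c = new Node(1);
    // First merge: the two rarest nodes, c (code 0) and b (code 1).
    Node h1 = new Node(b.freq + c.freq);
    c.parent = h1; b.parent = h1; c.code = 0; b.code = 1;
    // Second merge: h1 (freq 3, code 0) and a (freq 4, code 1).
    Node root = new Node(h1.freq + a.freq);
    h1.parent = root; a.parent = root; h1.code = 0; a.code = 1;

    // Same logic as makeNeurons(), applied to leaf c.
    List<Node> path = new LinkedList<>();
    for (Node n = c; (n = n.parent) != null; ) {
      path.add(n);
    }
    Collections.reverse(path); // [root, h1]
    int[] codeArr = new int[path.size()];
    for (int i = 1; i < path.size(); i++) {
      codeArr[i - 1] = path.get(i).code;
    }
    codeArr[codeArr.length - 1] = c.code;
    System.out.println(Arrays.toString(codeArr)); // prints [0, 0]
  }
}

Reading codeArr left to right reproduces the branches taken from the root: bit i is the branch chosen at path node i, so for c the walk is 0 to h1, then 0 to c itself.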

Step 4: train the model.

private void trainModel(File file) throws IOException {
    try (BufferedReader br = new BufferedReader(new InputStreamReader(
        new FileInputStream(file)))) {
      String temp = null;
      long nextRandom = 5;
      int wordCount = 0;
      int lastWordCount = 0;
      int wordCountActual = 0;
      while ((temp = br.readLine()) != null) {
        // Every ~10000 words, report progress and decay the learning rate
        // linearly toward a floor of startingAlpha * 0.0001.
        if (wordCount - lastWordCount > 10000) {
          System.out.println("alpha:" + alpha + "\tProgress: "
              + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100)
              + "%");
          wordCountActual += wordCount - lastWordCount;
          lastWordCount = wordCount;
          alpha = startingAlpha
              * (1 - wordCountActual / (double) (trainWordsCount + 1));
          if (alpha < startingAlpha * 0.0001) {
            alpha = startingAlpha * 0.0001;
          }
        }
        String[] strs = temp.split(" ");
        wordCount += strs.length;
        List<WordNeuron> sentence = new ArrayList<WordNeuron>();
        for (int i = 0; i < strs.length; i++) {
          Neuron entry = wordMap.get(strs[i]);
          if (entry == null) {
            continue; // skip words that are not in the vocabulary
          }
          // Subsampling: randomly discard frequent words while keeping the
          // frequency ranking the same.
          if (sample > 0) {
            double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1)
                * (sample * trainWordsCount) / entry.freq;
            nextRandom = nextRandom * 25214903917L + 11;
            if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
              continue;
            }
          }
          sentence.add((WordNeuron) entry);
        }

        // Train on each position of the (subsampled) sentence. The random
        // value shrinks the effective context window, as in the C original.
        for (int index = 0; index < sentence.size(); index++) {
          nextRandom = nextRandom * 25214903917L + 11;
          if (isCbow) {
            cbowGram(index, sentence, (int) nextRandom % window);
          } else {
            skipGram(index, sentence, (int) nextRandom % window);
          }
        }
      }
      System.out.println("Vocab size: " + wordMap.size());
      System.out.println("Words in train file: " + trainWordsCount);
      System.out.println("success: training finished!");
    }
  }
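
cbowGram and skipGram perform the actual hierarchical-softmax updates and are not reproduced in this excerpt. The sketch below is a rough reconstruction of the skip-gram case, with assumed field names (a word vector syn0 on WordNeuron, an inner-node vector syn1 on HiddenNeuron): each inner node on the center word's Huffman path acts as a binary logistic classifier whose target is the matching codeArr bit, and the gradient is applied both to the inner-node vector and to the context word's vector.

private void skipGram(int index, List<WordNeuron> sentence, int b) {
    WordNeuron word = sentence.get(index);
    // b comes from (int) nextRandom % window, so it may be negative;
    // |b| randomly narrows the effective context window.
    int shrink = Math.abs(b);
    for (int a = shrink; a < window * 2 + 1 - shrink; a++) {
      if (a == window) {
        continue; // skip the center word itself
      }
      int c = index - window + a;
      if (c < 0 || c >= sentence.size()) {
        continue;
      }
      WordNeuron ctx = sentence.get(c);
      double[] neu1e = new double[layerSize]; // gradient buffer for ctx.syn0
      List<Neuron> path = word.makeNeurons();
      for (int i = 0; i < path.size(); i++) {
        HiddenNeuron out = (HiddenNeuron) path.get(i);
        double f = 0;
        for (int j = 0; j < layerSize; j++) {
          f += ctx.syn0[j] * out.syn1[j];
        }
        f = 1.0 / (1.0 + Math.exp(-f));               // sigmoid
        double g = (1 - word.codeArr[i] - f) * alpha; // error * learning rate
        for (int j = 0; j < layerSize; j++) {
          neu1e[j] += g * out.syn1[j];    // accumulate for the context vector
          out.syn1[j] += g * ctx.syn0[j]; // update the inner-node vector
        }
      }
      for (int j = 0; j < layerSize; j++) {
        ctx.syn0[j] += neu1e[j]; // apply the accumulated update
      }
    }
  }

cbowGram is analogous, except that f is computed from the combined vector of the context words and the resulting gradient is distributed back over each of them.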