1,下载代码
git clone -b pytorch_bindings https://github.com/SeanNaren/warp-ctc.git
2,编译
cd warp-ctc
mkdir build
cd build
cmake .. # 这里可以自定义pytorch cpp的编译环境
make
# 自定义示例(该参数属于cmake,需在make之前执行): cmake .. -DCMAKE_PREFIX_PATH=/path/to/libtorch/9.0/10.0
3,设置环境变量
vi ~/.bashrc # 在文件末尾添加下面这一行
export WARP_CTC_PATH=/home/rose/software/warp-ctc/build
source ~/.bashrc # 使环境变量生效
4,将warp-ctc添加到conda
conda env list # 列出已有的conda环境(若conda不在PATH中,请先进入conda的bin目录再执行 ./conda env list)
# xxx 为上面命令列出的目标环境名(占位符,请替换为实际环境名)
source activate xxx
cd /home/rose/software/warp-ctc/pytorch_binding
python setup.py install
5,调用(先在shell中安装pytest,随后的代码需在Python中运行)
conda install pytest
import torch
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])
probs.requires_grad_(True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()