
Installing warp-ctc (PyTorch binding) on Ubuntu

辛成周
2023-12-01

1. Download the code

git clone -b pytorch_bindings https://github.com/SeanNaren/warp-ctc.git

2. Build

cd warp-ctc
mkdir build
cd build
cmake ..  # the PyTorch C++ build environment can be customized here, e.g.
          # cmake .. -DCMAKE_PREFIX_PATH=/path/to/libtorch/9.0/10.0
make

3. Set the environment variable

vi ~/.bashrc
# append this line, pointing at the build directory from step 2:
export WARP_CTC_PATH=/home/rose/software/warp-ctc/build
source ~/.bashrc                                          # reload so the variable takes effect
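As far as I can tell, the binding's `setup.py` looks for the compiled library in the directory named by `WARP_CTC_PATH`. When the variable is not set, it falls back to `../build`, relative to `pytorch_binding/`. Treat that as an assumption, not a description of the actual file. The minimal sketch below mirrors that lookup so you can check the variable before installing.

- `resolve_warpctc_path` and `check_warpctc_path` are hypothetical helpers.
- The library file names (`libwarpctc.so` on Linux, `libwarpctc.dylib` on macOS) are assumed.

```python
import os
import sys
from typing import Mapping, Optional

# Assumption: default build directory, relative to pytorch_binding/.
DEFAULT_BUILD_DIR = "../build"


def warpctc_lib_name() -> str:
    """Assumed file name of the compiled warp-ctc shared library."""
    return "libwarpctc.dylib" if sys.platform == "darwin" else "libwarpctc.so"


def resolve_warpctc_path(env: Mapping[str, str]) -> str:
    """Use WARP_CTC_PATH when it is set, otherwise fall back to the default build dir."""
    if "WARP_CTC_PATH" in env:
        return env["WARP_CTC_PATH"]
    return DEFAULT_BUILD_DIR


def check_warpctc_path(env: Mapping[str, str]) -> Optional[str]:
    """Return an error message if the library is missing, or None if it was found."""
    path = resolve_warpctc_path(env)
    lib = os.path.join(path, warpctc_lib_name())
    if not os.path.isfile(lib):
        return f"{lib} not found; build warp-ctc and export WARP_CTC_PATH=<build dir>"
    return None


if __name__ == "__main__":
    problem = check_warpctc_path(os.environ)
    print(problem or "warp-ctc library found")
```

Run it in the same shell after `source ~/.bashrc`. If it reports a missing library, the `setup.py install` in step 4 will most likely fail for the same reason.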

4. Install warp-ctc into a conda environment

conda env list             # list environments and pick yours, e.g. xxx
source activate xxx        # on newer conda versions: conda activate xxx
cd /home/rose/software/warp-ctc/pytorch_binding
python setup.py install
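A common mistake in this step is running `setup.py` with a Python that does not belong to the activated environment. The small check below uses only the standard library; `is_importable` is a hypothetical helper name. It prints which interpreter is running and whether that interpreter can find `warpctc_pytorch`.

```python
import importlib.util
import sys


def is_importable(module_name: str) -> bool:
    """True if the interpreter running this script can locate module_name."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Run with the python of the activated env (e.g. after `source activate xxx`).
    print("interpreter:", sys.executable)
    print("warpctc_pytorch importable:", is_importable("warpctc_pytorch"))
```

The printed interpreter path should lie inside the environment you activated. If it does not, re-activate the environment and rerun `python setup.py install`.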

5. Usage

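conda install pytest   # run inside the activated environment; used to run the binding's tests

Then, from Python:

import torch
from warpctc_pytorch import CTCLoss

ctc_loss = CTCLoss()
# expected shape: seqLength x batchSize x alphabet_size (here 2 x 1 x 5);
# warp-ctc applies the softmax internally, so these are raw (unnormalized) activations
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])       # targets of all sequences, concatenated; 0 is the blank
label_sizes = torch.IntTensor([2])     # length of each target sequence
probs_sizes = torch.IntTensor([2])     # number of valid time steps in each sequence
probs.requires_grad_(True)  # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
print(cost)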
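To sanity-check the number warp-ctc returns, here is a slow pure-Python reference of the standard CTC forward (alpha) recursion, computed in log space. It is a sketch for tiny inputs, not warp-ctc's implementation. It makes these assumptions, matching the example above:

- The inputs are raw activations, and the softmax is applied inside.
- The blank label is 0.
- `labels` is the flat concatenation of all target sequences, and `label_sizes` gives each sequence's length.
- The activations are nested lists in seqLength x batchSize x alphabet_size order.

`ctc_cost` and `batch_ctc_costs` are hypothetical names.

```python
import math
from typing import List, Sequence


def log_softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one frame of activations."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def logaddexp(a: float, b: float) -> float:
    """log(exp(a) + exp(b)), treating -inf as log(0)."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def ctc_cost(acts: Sequence[Sequence[float]], target: Sequence[int], blank: int = 0) -> float:
    """CTC negative log-likelihood of target for one sequence (T >= 1 frames)."""
    logp = [log_softmax(frame) for frame in acts]
    ext = [blank]
    for c in target:
        ext += [c, blank]  # extended label sequence: blank, l1, blank, l2, ..., blank
    S = len(ext)
    alpha = [-math.inf] * S
    alpha[0] = logp[0][ext[0]]
    if S > 1:
        alpha[1] = logp[0][ext[1]]
    for t in range(1, len(acts)):
        new = [-math.inf] * S
        for s in range(S):
            a = alpha[s]  # stay on the same symbol
            if s >= 1:
                a = logaddexp(a, alpha[s - 1])  # advance by one
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a = logaddexp(a, alpha[s - 2])  # skip a blank between distinct labels
            new[s] = a + logp[t][ext[s]]
        alpha = new
    total = alpha[0] if S == 1 else logaddexp(alpha[S - 1], alpha[S - 2])
    return -total


def batch_ctc_costs(acts, labels, act_sizes, label_sizes) -> List[float]:
    """Per-sequence costs for T x B x A activations and flat concatenated labels."""
    costs, offset = [], 0
    for b, (t_len, l_len) in enumerate(zip(act_sizes, label_sizes)):
        seq = [acts[t][b] for t in range(t_len)]
        costs.append(ctc_cost(seq, labels[offset:offset + l_len]))
        offset += l_len
    return costs


if __name__ == "__main__":
    acts = [[[0.1, 0.6, 0.1, 0.1, 0.1]], [[0.1, 0.1, 0.6, 0.1, 0.1]]]  # 2 x 1 x 5
    print(batch_ctc_costs(acts, [1, 2], [2], [2]))  # about 2.4629
```

In the example above only one alignment (label 1, then label 2) fits into two frames. The cost therefore reduces to -log p_1(1) - log p_2(2) = 2 ln(1 + 4e^-0.5) ≈ 2.4629. The warp-ctc cost should agree up to float32 rounding.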