当前位置: 首页 > 工具软件 > TensorWatch > 使用案例 >

Pytorch 可视化 tensorwatch 的使用

田巴英
2023-12-01

1、相关依赖(作者环境中 pip list 的输出,Python 3.7)

argcomplete                   1.12.3
argon2-cffi                   21.1.0
attrs                         21.2.0
backcall                      0.2.0
backports.functools-lru-cache 1.6.4
bleach                        4.1.0
certifi                       2021.10.8
cffi                          1.15.0
cycler                        0.11.0
debugpy                       1.4.1
decorator                     5.1.0
defusedxml                    0.7.1
entrypoints                   0.3
graphviz                      0.17
imageio                       2.10.1
importlib-metadata            4.8.1
ipykernel                     6.4.2
ipython                       7.29.0
ipython-genutils              0.2.0
ipywidgets                    7.6.5
jedi                          0.18.0
Jinja2                        3.0.2
joblib                        1.1.0
jsonschema                    4.1.2
jupyter-client                7.0.6
jupyter-core                  4.9.1
jupyterlab-pygments           0.1.2
jupyterlab-widgets            1.0.2
kiwisolver                    1.3.2
MarkupSafe                    2.0.1
matplotlib                    3.4.3
matplotlib-inline             0.1.3
mistune                       0.8.4
nbclient                      0.5.4
nbconvert                     6.2.0
nbformat                      5.1.3
nest-asyncio                  1.5.1
networkx                      2.6.3
notebook                      6.4.5
numpy                         1.21.3
packaging                     21.2
pandas                        1.3.4
pandocfilters                 1.5.0
parso                         0.8.2
pexpect                       4.8.0
pickleshare                   0.7.5
Pillow                        8.4.0
pip                           21.0.1
plotly                        5.3.1
prometheus-client             0.12.0
prompt-toolkit                3.0.22
ptyprocess                    0.7.0
pycparser                     2.20
pydot                         1.4.1
Pygments                      2.10.0
pyparsing                     2.4.7
pyrsistent                    0.18.0
python-dateutil               2.8.2
pytz                          2021.3
PyWavelets                    1.1.1
PyYAML                        6.0
pyzmq                         19.0.2
scikit-image                  0.18.3
scikit-learn                  1.0.1
scipy                         1.7.1
Send2Trash                    1.8.0
setuptools                    52.0.0.post20210125
six                           1.16.0
sklearn                       0.0
tenacity                      8.0.1
tensorwatch                   0.8.7
terminado                     0.12.1
testpath                      0.5.0
threadpoolctl                 3.0.0
tifffile                      2021.11.2
torch                         1.2.0+cpu
torchstat                     0.0.7
torchvision                   0.4.0+cpu
tornado                       6.1
traitlets                     5.1.1
typing-extensions             3.10.0.2
wcwidth                       0.2.5
webencodings                  0.5.1
wheel                         0.37.0
widgetsnbextension            3.5.2
zipp                          3.6.0

2、非 Python 依赖:Graphviz

sudo apt install graphviz # ubuntu

Windows 操作系统:下载并安装 Graphviz,安装时选择“为所有用户安装”,装完后重启电脑;Ubuntu 不需要重启。

3、修改已安装的 tensorwatch 包中的一行代码:

文件:~/anaconda3/envs/pt18/lib/python3.7/site-packages/tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py
(其中 anaconda3/envs/pt18 和 python3.7 是作者本机的环境路径,请换成你自己环境的 site-packages 目录)
# 大约第 13 行,将:
#     return self.dot._repr_svg_()
# 改为:
return self.dot.create_svg().decode()

4、测试:这里使用一个简单的 VGGNet(示例中构建的是 VGG19)

import torch
import torch.nn as nn
import torchvision
import tensorwatch as tw

def Conv3x3BNReLU(in_channels, out_channels):
    """Return a conv -> batch-norm -> ReLU6 unit as an ``nn.Sequential``.

    The convolution is 3x3 with stride 1 and padding 1, so the spatial size of
    the input is preserved; only the channel count changes.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution (and
            normalized by the BatchNorm2d layer).
    """
    conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )
    norm = nn.BatchNorm2d(out_channels)
    # In-place activation saves a temporary tensor.
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, act)

class VGGNet(nn.Module):
    """VGG-style image classifier built from ``Conv3x3BNReLU`` units.

    The network has five convolutional stages followed by a three-layer fully
    connected classifier. Each stage ends with a 2x2 / stride-2 max-pool that
    halves the spatial resolution. An adaptive average pool then resizes the
    final feature map to 7x7, so the classifier always receives 512*7*7
    features regardless of the input resolution.

    Args:
        block_nums: Indexable of (at least) five ints; ``block_nums[i]`` is the
            number of ``Conv3x3BNReLU`` units in stage ``i + 1``. A value <= 1
            still yields one conv unit for that stage.
        num_classes: Number of output logits. Defaults to 1000.
    """

    def __init__(self, block_nums, num_classes=1000):
        super(VGGNet, self).__init__()

        self.stage1 = self._make_layers(in_channels=3, out_channels=64, block_num=block_nums[0])
        self.stage2 = self._make_layers(in_channels=64, out_channels=128, block_num=block_nums[1])
        self.stage3 = self._make_layers(in_channels=128, out_channels=256, block_num=block_nums[2])
        self.stage4 = self._make_layers(in_channels=256, out_channels=512, block_num=block_nums[3])
        self.stage5 = self._make_layers(in_channels=512, out_channels=512, block_num=block_nums[4])

        # Resize the last feature map to 7x7 so inputs other than 224x224 still
        # match the classifier's 512*7*7 input. For 224x224 input the map is
        # already 7x7 and this is an exact identity. The layer has no
        # parameters, so state_dict keys and weight initialization are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(in_features=512*7*7, out_features=4096),
            nn.ReLU6(inplace=True),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU6(inplace=True),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=4096, out_features=num_classes)
        )

    def _make_layers(self, in_channels, out_channels, block_num):
        """Build one stage: ``block_num`` conv units followed by a 2x2 max-pool.

        The first unit maps ``in_channels`` -> ``out_channels``; the remaining
        ``block_num - 1`` units keep ``out_channels``. At least one conv unit is
        always created.
        """
        layers = [Conv3x3BNReLU(in_channels, out_channels)]
        layers.extend(Conv3x3BNReLU(out_channels, out_channels) for _ in range(1, block_num))
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        ``x`` is expected to be a batch of 3-channel images (N, 3, H, W).
        H and W must be at least 32 so that five 2x2 poolings leave a
        non-empty feature map.
        """
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        out = self.classifier(x)
        return out

def VGG16(num_classes=1000):
    """Return a VGG-16 style ``VGGNet``.

    Uses (2, 2, 3, 3, 3) conv units per stage: 13 conv layers plus the 3
    linear layers of the classifier.

    Args:
        num_classes: Number of output logits; defaults to 1000, matching the
            previous behavior of this factory.
    """
    block_nums = [2, 2, 3, 3, 3]
    model = VGGNet(block_nums, num_classes=num_classes)
    return model

def VGG19(num_classes=1000):
    """Return a VGG-19 style ``VGGNet``.

    Uses (2, 2, 4, 4, 4) conv units per stage: 16 conv layers plus the 3
    linear layers of the classifier.

    Args:
        num_classes: Number of output logits; defaults to 1000, matching the
            previous behavior of this factory.
    """
    block_nums = [2, 2, 4, 4, 4]
    model = VGGNet(block_nums, num_classes=num_classes)
    return model
# Build the VGG19 variant (block_nums = [2, 2, 4, 4, 4]) for visualization.
model = VGG19()
# In a Jupyter notebook: this call is meant to be the cell's last bare
# expression, so the notebook renders the returned drawing inline. Keep it at
# top level; wrapping it in an `if` block or a function would suppress that
# automatic display.
# The shape [1, 3, 224, 224] is presumably (batch, channels, height, width)
# -- TODO confirm against tensorwatch's draw_model documentation.
tw.draw_model(model, [1, 3, 224, 224])
# When running directly with `python xx.py` nothing is displayed; keep the
# returned object and save it to an image file instead:
# img = tw.draw_model(model, [1, 3, 224, 224])
# img.save('vgg19.png')





# Optional sanity check (disabled): print the architecture and push one random
# 3-channel 224x224 batch through the model to check the output shape.
# if __name__ == '__main__':
#     model = VGG19()
#     print(model)

#     input = torch.randn(1,3,224,224)
#     out = model(input)
#     print(out.shape)

 类似资料: