pytorch炼金术-One-Hot 编码

曾永新
2023-12-01

One-Hot 编码

1. F.one_hot

pytorch 现在自带的将标签转成one-hot编码方法

import torch.nn.functional as F
import torch

x=torch.randint(low=0,high=3,size=(2,2))# 随机生成一张2*2的灰度图.一共3个类别数。所以0,1,2
print(x)
y=F.one_hot(x)# 如果不加类别数,会默认使用 输入数据中最大值,作为列别数。一般还是会加的
print(y.shape)
print(y)
# pytorch做模型训练时 中需要进行转置。有点麻烦
y=torch.from_numpy(y.numpy().transpose(2,0,1))
print(y)

结果如下

tensor([[2, 0],
        [1, 0]])
torch.Size([2, 2, 3])
tensor([[[0, 0, 1],
         [1, 0, 0]],

        [[0, 1, 0],
         [1, 0, 0]]])
tensor([[[0, 1],
         [0, 1]],

        [[0, 0],
         [1, 0]],

        [[1, 0],
         [0, 0]]])

也可以使用

y=F.one_hot(x,num_classes=3)

效果一样

加载图片

import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image

# 加载8位图。
x=Image.open(r"../input/lidcidri/LIDC/mask_roi/LIDC_Mask_0000.png").convert("P")
# x的数据是0和225.所以得需要把225 变成1 ,才能做ont-hot
x=np.array(x)
x[x==225]=1# 如果图片已经处理好了,就不需要转标签了,比如VOC数据集。或者自己制作了labelimg的生成的多少8位图
x=torch.from_numpy(x).long()
y=F.one_hot(x,num_classes=2)# 如果不加类别数,会默认使用 输入数据中最大值,作为列别数。一般还是会加的
print(y.shape)
# print(y)
# pytorch做模型训练时 中需要进行转置。有点麻烦
y=torch.from_numpy(y.numpy().transpose(2,0,1))
print(y.shape)

结果

torch.Size([64, 64, 2])
torch.Size([2, 64, 64])

2. torch.scatter_

源代码

import torch
def to_one_hot(mask, n_class):
    """
    Transform a mask to one hot
    change a mask to n * h* w   n is the class
    Args:
        mask:
        n_class: number of class for segmentation
    Returns:
        y_one_hot: one hot mask
    """
    y_one_hot = torch.zeros((n_class, mask.shape[1], mask.shape[2]))
    y_one_hot = y_one_hot.scatter(0, mask, 1).long()
    return y_one_hot

# 这里一般使用 8位图加载图片信息。然后进行升维 。输入的时候是[1,高,宽]
# 返回结果就是 [num_class,H,W]  [类别数,高,宽]
x=torch.randint(low=0,high=3,size=(1,2,2))
print(x)
print(x.shape)
y=to_one_hot(x,n_class=3)
print(y)
print(y.shape)

结果

tensor([[[0, 2],
         [1, 0]]])
torch.Size([1, 2, 2])
tensor([[[1, 0],
       [0, 1]],
        
      [[0, 0],
       [1, 0]],
        
      [[0, 1],
       [0, 0]]])
torch.Size([3, 2, 2])

加载图片实例

import torch
import numpy as np
from PIL import Image
def to_one_hot(mask, n_class):
    """
    Transform a mask to one hot
    change a mask to n * h* w   n is the class
    Args:
        mask:
        n_class: number of class for segmentation
    Returns:
        y_one_hot: one hot mask
    """
    y_one_hot = torch.zeros((n_class, mask.shape[1], mask.shape[2]))
    y_one_hot = y_one_hot.scatter(0, mask, 1).long()
    return y_one_hot

# 这里一般使用 8位图加载图片信息。然后进行升维 。输入的时候是[1,高,宽]
# 返回结果就是 [num_class,H,W]  [类别数,高,宽]
x=Image.open(r"../input/lidcidri/LIDC/mask_roi/LIDC_Mask_0000.png").convert("P")
# x的数据是0和225.所以得需要把225 变成1 ,才能做ont-hot
x=np.array(x)
x[x==225]=1
x=torch.from_numpy(x).unsqueeze(0).long() # 把x从numpy--->tensor
# x=torch.randint(low=0,high=3,size=(1,2,2))
print(x)
print(x.shape)
y=to_one_hot(x,n_class=2)
print(y)
print(y.shape)
 类似资料: