完整代码下载地址:下载地址
代码流程如下：
`
def main():
random_seed = 0
num_classes = 10
cuda_device = torch.device("cuda:0")
cpu_device = torch.device("cpu:0")
model_dir = "saved_models"
model_filename = "resnet18_cifar10.pt"
quantized_model_filename = "resnet18_quantized_cifar10.pt"
model_filepath = os.path.join(model_dir, model_filename)
quantized_model_filepath = os.path.join(
model_dir, quantized_model_filename)
set_random_seeds(random_seed=random_seed)
# Create an untrained model.
model = create_mode