# train.py
import sonnet as snt
import tensorflow as tf
from dataloader import *
from models import *
import os
# Training script: builds the Captcha graph network, restores the newest
# checkpoint under ./checkpoints if one exists, and trains with Adam on
# batches produced by `Code`, checkpointing every 1000 iterations and at the
# end of every pass over the data.
# NOTE(review): `Captcha` and `Code` arrive via the star imports above; in this
# paste they are defined in model.py / code.py while the imports name
# `models` / `dataloader` -- confirm the module names.
checkpoint_root = "./checkpoints"
checkpoint_name = "model"
save_prefix = os.path.join(checkpoint_root, checkpoint_name)

graph_network = Captcha()
checkpoint = tf.train.Checkpoint(module=graph_network)
latest = tf.train.latest_checkpoint(checkpoint_root)
if latest is not None:
    checkpoint.restore(latest)

# Positional args: path, batch_size=4, stride=8, img_width=56, img_height=100.
code = Code("../pytorch_verification_code/dataset/train", 4, 8, 56, 100)

# Despite the name, this is the number of passes ("echoes") over the dataset.
max_iteration = 1000000
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

for echo in range(max_iteration):
    code.mess_up_order()
    # The metric accumulates across calls; restart it each pass so the logged
    # accuracy describes the current pass rather than all training so far.
    train_accuracy.reset_state()
    for i in range(code.total_number):
        # Building the batch touches no trainable variables, so it does not
        # need to be recorded by the tape.
        nodes, edges_index, edges_attr, u, batch, target = code.next_batch(i)
        with tf.GradientTape() as gen_tape:
            # Only the global output of the last stage is used as class logits.
            _, _, output = graph_network(nodes, edges_index, edges_attr, u, batch)
            # Reduce to a scalar: differentiating a per-example vector would
            # yield the gradient of its *sum*, not of the mean that is logged.
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=target))
        gradients_of_generator = gen_tape.gradient(loss, graph_network.trainable_variables)
        generator_optimizer.apply_gradients(
            zip(gradients_of_generator, graph_network.trainable_variables))
        print('Echo %d,Iter [%d/%d]: train_loss is: %.5f train_accuracy is: %.5f'
              % (echo + 1, i + 1, code.total_number, loss, train_accuracy(target, output)))
        if i and i % 1000 == 0:
            checkpoint.save(save_prefix)
    checkpoint.save(save_prefix)
# model.py
import sonnet as snt
import tensorflow as tf
class Mish(snt.Module):
    """Element-wise Mish activation, ``x * tanh(softplus(x))``.

    It creates no variables. It is a Sonnet module only so that it can be
    placed inside ``snt.Sequential`` next to ``snt.Linear`` layers.
    """

    def __init__(self):
        super(Mish, self).__init__()

    def __call__(self, x):
        gate = tf.math.tanh(tf.math.softplus(x))
        return x * gate
class EdgeModel(snt.Module):
    """Edge-update function of one graph-network block.

    Concatenates, per edge, the source-node features, the destination-node
    features, the current edge features and the global features of the
    edge's graph. It maps the result through a two-layer MLP (1024 hidden
    units, Mish) to ``OUTPUT_EDGE_SIZE`` features.
    """

    def __init__(self, OUTPUT_EDGE_SIZE):
        super(EdgeModel, self).__init__()
        self.OUTPUT_EDGE_SIZE = OUTPUT_EDGE_SIZE
        # snt.Linear infers its input width on first call; here that width is
        # 2 * F_x + F_e + F_u.
        self.edge_mlp = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_EDGE_SIZE),
        ])

    def __call__(self, src, dest, edge_attr, u, batch):
        """Return updated edge features.

        Args:
          src: [E, F_x] features of each edge's source node.
          dest: [E, F_x] features of each edge's destination node.
          edge_attr: [E, F_e] current edge features.
          u: [B, F_u] global features, one row per graph.
          batch: graph index of every edge (E entries, max B - 1); a
            trailing size-1 axis is accepted.

        Returns:
          [E, OUTPUT_EDGE_SIZE] tensor.
        """
        # Gather with the index tensor itself (no .numpy()), so this also
        # traces under tf.function. Flattening accepts both [E] and [E, 1].
        edge_globals = tf.gather(u, tf.reshape(batch, [-1]))
        out = tf.concat([src, dest, edge_attr, edge_globals], 1)
        return self.edge_mlp(out)
class NodeModel(snt.Module):
    """Node-update function of one graph-network block.

    For every edge, ``node_mlp_1`` turns ``[x[source], edge_attr]`` into a
    message. The messages are averaged per destination node, and the mean is
    concatenated with the node's own features and its graph's global
    features. ``node_mlp_2`` maps that to ``OUTPUT_NODE_SIZE`` features.
    """

    def __init__(self, OUTPUT_NODE_SIZE):
        super(NodeModel, self).__init__()
        self.OUTPUT_NODE_SIZE = OUTPUT_NODE_SIZE
        # Message MLP; input width F_x + F_e.
        self.node_mlp_1 = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_NODE_SIZE),
        ])
        # Update MLP; input width F_x + OUTPUT_NODE_SIZE + F_u.
        self.node_mlp_2 = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_NODE_SIZE),
        ])

    def __call__(self, x, edge_index, edge_attr, u, batch):
        """Return updated node features.

        Args:
          x: [N, F_x] node features.
          edge_index: [2, E]; row 0 holds source ids, row 1 destination ids.
          edge_attr: [E, F_e] edge features.
          u: [B, F_u] global features.
          batch: graph index of every node (N entries, max B - 1); a
            trailing size-1 axis is accepted.

        Returns:
          [N, OUTPUT_NODE_SIZE] tensor.
        """
        # Index rows explicitly: unpacking a tensor only works eagerly.
        row = edge_index[0]
        col = edge_index[1]
        messages = self.node_mlp_1(tf.concat([tf.gather(x, row), edge_attr], 1))
        # tf.shape keeps num_segments valid when N is not statically known.
        aggregated = tf.math.unsorted_segment_mean(
            messages, col, num_segments=tf.shape(x)[0])
        # Flatten instead of an unqualified squeeze, which would also drop a
        # size-1 N or F_u axis.
        node_globals = tf.gather(u, tf.reshape(batch, [-1]))
        out = tf.concat([x, aggregated, node_globals], 1)
        return self.node_mlp_2(out)
class GlobalModel(snt.Module):
    """Global-update function of one graph-network block.

    Concatenates each graph's current global features with the mean of its
    nodes' features. It maps the result through a two-layer MLP (1024 hidden
    units, Mish) to ``OUTPUT_GLOBAL_SIZE`` features per graph.
    """

    def __init__(self, OUTPUT_GLOBAL_SIZE):
        super(GlobalModel, self).__init__()
        self.OUTPUT_GLOBAL_SIZE = OUTPUT_GLOBAL_SIZE
        # snt.Linear infers its input width on first call: F_u + F_x.
        self.global_mlp = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_GLOBAL_SIZE),
        ])

    def __call__(self, x, edge_index, edge_attr, u, batch):
        """Return updated global features.

        Args:
          x: [N, F_x] node features.
          edge_index: unused; kept for a uniform sub-model signature.
          edge_attr: unused; kept for a uniform sub-model signature.
          u: [B, F_u] current global features.
          batch: graph index of every node (N entries, max B - 1); a
            trailing size-1 axis is accepted.

        Returns:
          [B, OUTPUT_GLOBAL_SIZE] tensor.
        """
        # Flatten rather than squeeze: squeezing a single-node [1, 1] batch
        # would yield a scalar segment id.
        node_mean = tf.math.unsorted_segment_mean(
            x, tf.reshape(batch, [-1]), num_segments=tf.shape(u)[0])
        out = tf.concat([u, node_mean], 1)
        return self.global_mlp(out)
class GraphNetwork(snt.Module):
    """One graph-network block: edge update, then node update, then global update.

    Any sub-model may be ``None``; that stage is then skipped and its input
    passes through unchanged. Later stages see the outputs of earlier ones.
    """

    def __init__(self, edge_model=None, node_model=None, global_model=None):
        super(GraphNetwork, self).__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model

    def __call__(self, x, edge_index, edge_attr=None, u=None, batch=None):
        """Apply the configured stages and return ``(x, edge_attr, u)``.

        ``edge_index`` is [2, E] (row 0 sources, row 1 destinations) and
        ``batch`` gives the graph index of every node. The sub-models in this
        file index ``u`` with it, so pass it whenever they are used.
        """
        if self.edge_model is not None:
            row = edge_index[0]
            col = edge_index[1]
            # An edge belongs to the graph of its source node. Gather with the
            # tensor itself (no .numpy()) so this also works in graph mode.
            edge_batch = None if batch is None else tf.gather(tf.reshape(batch, [-1]), row)
            edge_attr = self.edge_model(tf.gather(x, row), tf.gather(x, col),
                                        edge_attr, u, edge_batch)
        if self.node_model is not None:
            x = self.node_model(x, edge_index, edge_attr, u, batch)
        if self.global_model is not None:
            u = self.global_model(x, edge_index, edge_attr, u, batch)
        return x, edge_attr, u
class Captcha(snt.Module):
    """Five stacked GraphNetwork blocks with progressively smaller widths.

    Each stage sees the node, edge and global features produced by the
    previous one. Connectivity (``edge_index``) and graph membership
    (``batch``) are shared by all stages. The last stage's global output has
    width 36.
    """

    def __init__(self):
        super(Captcha, self).__init__()
        # The attribute names GN_1..GN_5 are what tf.train.Checkpoint uses to
        # key these sub-modules' variables, so they must stay as they are.
        self.GN_1 = GraphNetwork(edge_model=EdgeModel(32), node_model=NodeModel(96),
                                 global_model=GlobalModel(1024))
        self.GN_2 = GraphNetwork(edge_model=EdgeModel(16), node_model=NodeModel(48),
                                 global_model=GlobalModel(512))
        self.GN_3 = GraphNetwork(edge_model=EdgeModel(8), node_model=NodeModel(24),
                                 global_model=GlobalModel(256))
        self.GN_4 = GraphNetwork(edge_model=EdgeModel(4), node_model=NodeModel(12),
                                 global_model=GlobalModel(64))
        self.GN_5 = GraphNetwork(edge_model=EdgeModel(2), node_model=NodeModel(6),
                                 global_model=GlobalModel(36))

    def __call__(self, x, edge_index, edge_attr, u, batch):
        node_feats, edge_feats, global_feats = x, edge_attr, u
        for stage in (self.GN_1, self.GN_2, self.GN_3, self.GN_4, self.GN_5):
            node_feats, edge_feats, global_feats = stage(
                node_feats, edge_index, edge_feats, global_feats, batch)
        return node_feats, edge_feats, global_feats
# code.py
import tensorflow as tf
import numpy as np
from pathlib import Path
import cv2
import random
class Code:
    """Loads captcha character images and turns them into batched grid graphs.

    Every image is cut into ``stride`` x ``stride`` pixel blocks. Each block
    becomes a node whose features are its flattened RGB pixels, rescaled from
    [0, 255] to [-1, 1]. Each node is linked to its up/down/left/right
    neighbours by directed edges with the constant attribute ``[1, 1]``. A
    batch packs several such graphs into one disjoint graph (see
    ``next_batch``).
    """

    def __init__(self, path_kind, batch_size, stride, img_width, img_height):
        """Index the images under ``path_kind``.

        Args:
          path_kind: directory whose entries (every ``*`` glob match) are the
            images to load.
          batch_size: number of images (graphs) per batch.
          stride: side length in pixels of one square block, i.e. one node.
          img_width, img_height: expected image size in pixels. Pixels beyond
            the last whole block in either direction are ignored.
        """
        # Class alphabet: a character's position is its label index.
        self.alpha = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b',
                      'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p',
                      'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
        self.data_root = Path(path_kind)
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.stride = stride
        self.block_height = int(img_height / stride)
        self.block_width = int(img_width / stride)
        self.load()

    def load(self):
        """List every entry of ``data_root`` and count the full batches."""
        self.second_image_paths = [str(path) for path in self.data_root.glob('*')]
        # A trailing partial batch is never served.
        self.total_number = len(self.second_image_paths) // self.batch_size

    def mess_up_order(self):
        """Shuffle the image order in place (call once per epoch)."""
        random.shuffle(self.second_image_paths)

    def List2Tensor(self, x):
        """Convert ``x`` to a tensor reshaped to 2-D, keeping its first axis."""
        return tf.reshape(x, [tf.shape(x)[0], -1])

    def _parse_label(self, path):
        """Return the class index encoded in the file name of ``path``.

        The label text runs from two characters after the first ``_`` up to
        the first ``.`` of the file name and is looked up in ``self.alpha``.
        """
        # Path.name is separator-agnostic, unlike splitting on '/'.
        file_name = Path(path).name
        begin = file_name.find('_')
        end = file_name.find('.')
        return self.alpha.index(file_name[begin + 2:end])

    def _load_image(self, path):
        """Read ``path`` as RGB and rescale its pixel values to [-1, 1]."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError('could not read image: %s' % path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return ((image / 255.) - .5) * 2.

    def _block_features(self, image):
        """Cut ``image`` into row-major ``stride`` x ``stride`` blocks.

        Returns an array of shape [block_height * block_width,
        stride * stride * channels]. Row ``i * block_width + j`` is block
        (i, j), flattened in (row, column, channel) order.
        """
        s = self.stride
        rows, cols = self.block_height, self.block_width
        crop = image[:rows * s, :cols * s]
        channels = crop.shape[2]
        # [i, r, j, c, ch] -> [i, j, r, c, ch], then flatten each block.
        blocks = crop.reshape(rows, s, cols, s, channels).transpose(0, 2, 1, 3, 4)
        return blocks.reshape(rows * cols, s * s * channels)

    def _grid_edges(self, offset):
        """Directed 4-neighbour edges of one block grid.

        Node ids are shifted by ``offset``. Every ordered (node, neighbour)
        pair appears exactly once. Emitting each pair twice would only double
        the edge work, since mean aggregation is unchanged by uniformly
        repeated edges.

        Returns:
          (sources, targets) lists of equal length.
        """
        sources, targets = [], []
        for i in range(self.block_height):
            for j in range(self.block_width):
                node = offset + i * self.block_width + j
                for di, dj in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                    ni, nj = i + di, j + dj
                    if 0 <= ni < self.block_height and 0 <= nj < self.block_width:
                        sources.append(node)
                        targets.append(offset + ni * self.block_width + nj)
        return sources, targets

    def next_batch(self, index):
        """Build batch ``index`` (0-based) as one disjoint union of grid graphs.

        Graph ``k`` of the batch owns node ids
        ``[k * blocks, (k + 1) * blocks)``, where
        ``blocks = block_height * block_width``.

        Returns:
          nodes:       float32 [N, stride*stride*channels] block pixels.
          edges_index: int32 [2, E]; row 0 sources, row 1 targets.
          edges_attr:  float32 [E, 2], all ones.
          u:           float32 [B, blocks], mean of each block.
          batch:       int32 [N, 1], graph index of every node.
          labels:      one-hot [B, len(alpha)] float32.
        """
        paths = self.second_image_paths[self.batch_size * index:self.batch_size * (index + 1)]
        nodes_per_graph = self.block_height * self.block_width
        nodes, u, labels, batch = [], [], [], []
        edges_index = [[], []]
        for k, path in enumerate(paths):
            labels.append(self._parse_label(path))
            features = self._block_features(self._load_image(path))
            nodes.append(features)
            u.append(features.mean(axis=1))
            batch.extend([k] * nodes_per_graph)
            sources, targets = self._grid_edges(k * nodes_per_graph)
            edges_index[0].extend(sources)
            edges_index[1].extend(targets)
        edges_attr = [[1, 1]] * len(edges_index[0])
        return (tf.cast(self.List2Tensor(np.concatenate(nodes)), dtype=tf.float32),
                tf.cast(self.List2Tensor(edges_index), dtype=tf.int32),
                tf.cast(self.List2Tensor(edges_attr), dtype=tf.float32),
                tf.cast(self.List2Tensor(np.stack(u)), dtype=tf.float32),
                tf.cast(self.List2Tensor(batch), dtype=tf.int32),
                tf.one_hot(labels, len(self.alpha)))
# Manual smoke test for Code.next_batch, disabled by wrapping it in a bare
# string literal: this statement evaluates the string and discards it, so
# nothing below runs. Once enabled, it would build batch 7 and print the
# shapes of the returned tensors.
"""
batch_size = 3
img_height = 100
img_width = 56
stride = 8
code = Code("../../pytorch_verification_code/dataset/train",batch_size,stride,img_width,img_height)
nodes,edges_index,edges_attr,u,batch,labels = code.next_batch(7)
print(np.array(nodes).shape)
print(np.array(edges_index).shape)
print(np.array(edges_attr).shape)
print(np.array(u).shape)
print(np.array(batch).shape)
print(np.array(labels).shape)
print(labels)
"""