当前位置: 首页 > 工具软件 > Sonnet > 使用案例 >

sonnet_graph_nets

微生俊
2023-12-01

train.py



import sonnet as snt
import tensorflow as tf
from dataloader import *
from models import *
import os


checkpoint_root = "./checkpoints"
checkpoint_name = "model"

graph_network = Captcha()
checkpoint = tf.train.Checkpoint(module=graph_network)
# The manager names checkpoints ``<checkpoint_root>/<checkpoint_name>-N`` and
# deletes all but the newest ``max_to_keep`` of them, so frequent saves during
# a long run cannot fill the disk.
manager = tf.train.CheckpointManager(
    checkpoint, checkpoint_root, max_to_keep=5, checkpoint_name=checkpoint_name)

# Resume from the newest checkpoint in ``checkpoint_root`` when one exists.
latest = manager.latest_checkpoint
if latest is not None:
    checkpoint.restore(latest)

# Positional args: path_kind, batch_size=4, stride=8, img_width=56, img_height=100.
code = Code("../pytorch_verification_code/dataset/train", 4, 8, 56, 100)
max_iteration = 1000000  # number of epochs (full passes over the dataset)
save_every = 1000  # additionally checkpoint every this many batches within an epoch
generator_optimizer = tf.keras.optimizers.Adam(1e-4)

for epoch in range(max_iteration):
    code.mess_up_order()
    # Recreated every epoch so the printed accuracy covers the current epoch
    # only; a single never-reset metric would average over the whole run.
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

    for i in range(code.total_number):
        with tf.GradientTape() as gen_tape:
            nodes, edges_index, edges_attr, u, batch, target = code.next_batch(i)
            # The global output ``u`` of the last block serves as the per-graph logits.
            _, _, output = graph_network(nodes, edges_index, edges_attr, u, batch)
            # One loss value per graph; GradientTape differentiates their sum.
            loss = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=target)

        gradients = gen_tape.gradient(loss, graph_network.trainable_variables)
        generator_optimizer.apply_gradients(zip(gradients, graph_network.trainable_variables))

        print('Echo %d,Iter [%d/%d]: train_loss is: %.5f train_accuracy is: %.5f' % (
            epoch + 1, i + 1, code.total_number,
            tf.reduce_mean(loss), train_accuracy(target, output)))

        if i and i % save_every == 0:
            manager.save()

    # Save once per epoch as well, so progress is kept even when an epoch has
    # fewer than ``save_every`` batches.
    manager.save()

model.py (note: train.py does `from models import *`, so save this file as models.py or change that import)


import sonnet as snt
import tensorflow as tf


class Mish(snt.Module):
    """Element-wise Mish activation: ``x * tanh(softplus(x))``.

    The module holds no variables; it only transforms its input.
    """

    def __init__(self):
        super(Mish, self).__init__()

    def __call__(self, x):
        gate = tf.math.tanh(tf.math.softplus(x))
        return x * gate

        
class EdgeModel(snt.Module):
    """Edge update of a graph-network block.

    For every edge, concatenates the source-node features, target-node
    features, current edge features and the global features of the graph the
    edge belongs to. The result goes through the MLP
    Linear(1024) -> Mish -> Linear(OUTPUT_EDGE_SIZE).

    Args:
        OUTPUT_EDGE_SIZE: number of features of the updated edges.
    """

    def __init__(self, OUTPUT_EDGE_SIZE):
        super(EdgeModel, self).__init__()
        self.OUTPUT_EDGE_SIZE = OUTPUT_EDGE_SIZE

        self.edge_mlp = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_EDGE_SIZE),
        ])

    def __call__(self, src, dest, edge_attr, u, batch):
        """Return the updated edge features, ``[E, OUTPUT_EDGE_SIZE]``.

        Args:
            src, dest: ``[E, F_x]`` features of each edge's source / target node,
                where E is the number of edges.
            edge_attr: ``[E, F_e]`` current edge features.
            u: ``[B, F_u]`` global features, B being the number of graphs.
            batch: ``[E]`` (or ``[E, 1]``) graph id of every edge, max entry B - 1.
        """
        # Index with the tensor itself rather than ``batch.numpy()`` so the
        # module also runs inside ``tf.function``; flattening accepts [E, 1] ids.
        graph_ids = tf.reshape(batch, [-1])
        out = tf.concat([src, dest, edge_attr, tf.gather(u, graph_ids)], 1)
        return self.edge_mlp(out)

class NodeModel(snt.Module):
    """Node update of a graph-network block.

    Each edge yields a message ``node_mlp_1([x[row], edge_attr])``. Messages are
    averaged per target node ``col``. The mean is concatenated with the node's
    own features and its graph's global features, then passed through
    ``node_mlp_2``. Both MLPs are Linear(1024) -> Mish -> Linear(OUTPUT_NODE_SIZE).

    Args:
        OUTPUT_NODE_SIZE: number of features of the updated nodes.
    """

    def __init__(self, OUTPUT_NODE_SIZE):
        super(NodeModel, self).__init__()
        self.OUTPUT_NODE_SIZE = OUTPUT_NODE_SIZE

        self.node_mlp_1 = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_NODE_SIZE),
        ])

        self.node_mlp_2 = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_NODE_SIZE),
        ])

    def __call__(self, x, edge_index, edge_attr, u, batch):
        """Return the updated node features, ``[N, OUTPUT_NODE_SIZE]``.

        Args:
            x: ``[N, F_x]`` node features, N being the number of nodes.
            edge_index: ``[2, E]`` source (row 0) / target (row 1) node ids.
            edge_attr: ``[E, F_e]`` edge features.
            u: ``[B, F_u]`` global features.
            batch: ``[N]`` or ``[N, 1]`` graph id of every node, max entry B - 1.
        """
        # Plain indexing instead of tuple-unpacking / ``.numpy()`` keeps this
        # usable inside ``tf.function``.
        row, col = edge_index[0], edge_index[1]
        messages = self.node_mlp_1(tf.concat([tf.gather(x, row), edge_attr], 1))
        # Dynamic node count (the static ``x.shape[0]`` is None in graph mode);
        # nodes that receive no edge get a zero mean.
        aggregated = tf.compat.v2.math.unsorted_segment_mean(
            messages, col, num_segments=tf.shape(x)[0])
        # reshape instead of squeeze: squeeze would also drop the feature axis
        # when F_u == 1 or N == 1.
        graph_ids = tf.reshape(batch, [-1])
        out = tf.concat([x, aggregated, tf.gather(u, graph_ids)], 1)
        return self.node_mlp_2(out)

class GlobalModel(snt.Module):
    """Global update of a graph-network block.

    Concatenates each graph's current global features with the mean of its
    node features and applies Linear(1024) -> Mish -> Linear(OUTPUT_GLOBAL_SIZE).

    Args:
        OUTPUT_GLOBAL_SIZE: number of features of the updated globals.
    """

    def __init__(self, OUTPUT_GLOBAL_SIZE):
        super(GlobalModel, self).__init__()
        self.OUTPUT_GLOBAL_SIZE = OUTPUT_GLOBAL_SIZE

        self.global_mlp = snt.Sequential([
            snt.Linear(1024),
            Mish(),
            snt.Linear(self.OUTPUT_GLOBAL_SIZE),
        ])

    def __call__(self, x, edge_index, edge_attr, u, batch):
        """Return the updated global features, ``[B, OUTPUT_GLOBAL_SIZE]``.

        Args:
            x: ``[N, F_x]`` node features, N being the number of nodes.
            edge_index: ``[2, E]`` node ids (unused here).
            edge_attr: ``[E, F_e]`` edge features (unused here).
            u: ``[B, F_u]`` global features.
            batch: ``[N]`` or ``[N, 1]`` graph id of every node, max entry B - 1.
        """
        # reshape (not squeeze) so a single-node batch keeps a rank-1 id vector;
        # the dynamic graph count also works when the static shape is unknown.
        graph_ids = tf.reshape(batch, [-1])
        node_mean = tf.compat.v2.math.unsorted_segment_mean(
            x, graph_ids, num_segments=tf.shape(u)[0])
        out = tf.concat([u, node_mean], 1)
        return self.global_mlp(out)
		
class GraphNetwork(snt.Module):
    """One graph-network block: edge update, then node update, then global update.

    Each sub-model is optional. A ``None`` model leaves the corresponding
    features unchanged. Later stages receive the outputs of earlier stages.
    """

    def __init__(self, edge_model=None, node_model=None, global_model=None):
        super(GraphNetwork, self).__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model

    def __call__(self, x, edge_index, edge_attr=None, u=None, batch=None):
        """Apply the block and return ``(x, edge_attr, u)``.

        Args:
            x: ``[N, F_x]`` node features.
            edge_index: ``[2, E]`` source (row 0) and target (row 1) node ids.
            edge_attr: ``[E, F_e]`` edge features.
            u: ``[B, F_u]`` global features, one row per graph.
            batch: ``[N]`` or ``[N, 1]`` graph id of each node, or ``None``.
        """
        # Indexing (not unpacking / ``.numpy()``) keeps this usable in ``tf.function``.
        row, col = edge_index[0], edge_index[1]

        if self.edge_model is not None:
            # The graph id of an edge is the graph id of its source node.
            edge_batch = None if batch is None else tf.gather(tf.reshape(batch, [-1]), row)
            edge_attr = self.edge_model(
                tf.gather(x, row), tf.gather(x, col), edge_attr, u, edge_batch)

        if self.node_model is not None:
            x = self.node_model(x, edge_index, edge_attr, u, batch)

        if self.global_model is not None:
            u = self.global_model(x, edge_index, edge_attr, u, batch)

        return x, edge_attr, u


class Captcha(snt.Module):
    """Captcha classifier built from five stacked graph-network blocks.

    Feature widths shrink block by block; the last block's global output has
    36 features. Every block shares the same connectivity (``edge_index``) and
    node-to-graph assignment (``batch``).
    """

    def __init__(self):
        super(Captcha, self).__init__()
        # Keep the GN_1..GN_5 attribute names: tf.train.Checkpoint tracks
        # sub-modules by attribute name, so renaming would orphan old checkpoints.
        self.GN_1 = GraphNetwork(EdgeModel(32), NodeModel(96), GlobalModel(1024))
        self.GN_2 = GraphNetwork(EdgeModel(16), NodeModel(48), GlobalModel(512))
        self.GN_3 = GraphNetwork(EdgeModel(8), NodeModel(24), GlobalModel(256))
        self.GN_4 = GraphNetwork(EdgeModel(4), NodeModel(12), GlobalModel(64))
        self.GN_5 = GraphNetwork(EdgeModel(2), NodeModel(6), GlobalModel(36))

    def __call__(self, x, edge_index, edge_attr, u, batch):
        """Run the blocks in order and return the final ``(x, edge_attr, u)``."""
        nodes, edges, graph_feats = x, edge_attr, u
        for block in (self.GN_1, self.GN_2, self.GN_3, self.GN_4, self.GN_5):
            nodes, edges, graph_feats = block(nodes, edge_index, edges, graph_feats, batch)
        return nodes, edges, graph_feats
		

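code.py (note: train.py does `from dataloader import *`, so save this file as dataloader.py or change that import; a file named code.py would also shadow Python's standard-library `code` module)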



import tensorflow as tf

import numpy as np
from pathlib import Path
import cv2
import random

class Code:
    """Captcha dataset that turns each image into a 4-neighbour grid graph.

    Each image is cut into ``block_height x block_width`` non-overlapping
    ``stride x stride`` patches. Every patch becomes one node whose features
    are its flattened RGB pixels scaled from [0, 255] to [-1, 1]. The graph's
    global features are the per-patch means. The class label is read from the
    file name.

    Args:
        path_kind: directory whose entries are the image files.
        batch_size: number of images (graphs) per batch.
        stride: patch side length in pixels.
        img_width, img_height: image size in pixels; pixels that do not fill a
            whole patch at the right/bottom edge are ignored.
    """

    def __init__(self, path_kind, batch_size, stride, img_width, img_height):
        # Label alphabet: a label's class index is its position in this list.
        self.alpha = list('0123456789abcdefghijklmnopqrstuvwxyz')

        self.data_root = Path(path_kind)
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.stride = stride

        # Number of whole patches along each axis.
        self.block_height = int(img_height // stride)
        self.block_width = int(img_width // stride)

        self.load()

    def load(self):
        """Collect the paths of all entries in ``data_root`` and count full batches."""
        self.second_image_paths = [str(path) for path in self.data_root.glob('*')]
        # A trailing partial batch is dropped.
        self.total_number = len(self.second_image_paths) // self.batch_size

    def mess_up_order(self):
        """Shuffle the image order in place."""
        random.shuffle(self.second_image_paths)

    def List2Tensor(self, x):
        """Convert ``x`` to a 2-D tensor of shape ``[len(x), -1]``.

        A flat list therefore becomes a column of shape ``[len(x), 1]``.
        """
        return tf.reshape(x, [tf.shape(x)[0], -1])

    def _parse_label(self, path):
        """Return the class index encoded in the file name of ``path``.

        The label text runs from two characters after the first ``'_'`` up to
        the first ``'.'`` of the file name and is looked up in ``self.alpha``.

        Raises:
            ValueError: if that text is not in the alphabet.
        """
        # Path(...).name follows the platform's path rules; splitting on '/'
        # fails for Windows paths.
        name = Path(path).name
        begin = name.find('_')
        end = name.find('.')
        return self.alpha.index(name[begin + 2:end])

    def _grid_edges(self, base):
        """Return ``(sources, targets)`` of the grid edges for one image.

        ``base`` is the id of the image's first node; nodes are numbered
        row-major. For every node, each existing up/down/left/right neighbour
        adds both ``node -> neighbour`` and ``neighbour -> node``. Both endpoints
        do this, so every directed edge appears twice.
        """
        height, width = self.block_height, self.block_width
        sources, targets = [], []
        for i in range(height):
            for j in range(width):
                cur = base + i * width + j
                neighbours = []
                if i > 0:
                    neighbours.append(cur - width)  # up
                if i < height - 1:
                    neighbours.append(cur + width)  # down
                if j > 0:
                    neighbours.append(cur - 1)  # left
                if j < width - 1:
                    neighbours.append(cur + 1)  # right
                for nb in neighbours:
                    sources += (cur, nb)
                    targets += (nb, cur)
        return sources, targets

    def next_batch(self, index):
        """Build batch ``index`` as one disjoint union of per-image graphs.

        Returns a tuple ``(nodes, edges_index, edges_attr, u, batch, labels)``:
            nodes:       float32 ``[N, F]`` flattened patch pixels
                         (F = stride * stride * channels, assuming every image
                         is at least img_height x img_width).
            edges_index: int32 ``[2, E]`` source / target node ids.
            edges_attr:  float32 ``[E, 2]``, all ones.
            u:           float32 ``[B, block_height * block_width]`` patch means.
            batch:       int32 ``[N, 1]`` graph id (0..B-1) of every node.
            labels:      ``[B, len(alpha)]`` one-hot labels.
        Here B is the number of images in the batch and
        ``N = B * block_height * block_width``.

        Raises:
            IOError: if an image cannot be read.
            ValueError: if a file name does not encode a known label.
        """
        nodes, edges_attr, u, labels, batch = [], [], [], [], []
        edges_index = [[], []]
        nodes_per_image = self.block_height * self.block_width
        s = self.stride
        start = self.batch_size * index

        for k, path in enumerate(self.second_image_paths[start:start + self.batch_size]):
            label = self._parse_label(path)

            image = cv2.imread(path)
            # cv2.imread returns None instead of raising on unreadable files.
            if image is None:
                raise IOError("could not read image: %s" % path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = ((image / 255.) - .5) * 2.  # [0, 255] -> [-1, 1]

            patch_means = []
            for i in range(self.block_height):
                for j in range(self.block_width):
                    patch = image[i * s:(i + 1) * s, j * s:(j + 1) * s].flatten()
                    patch_means.append(np.mean(patch))
                    nodes.append(patch.tolist())

            # Node ids are global across the batch: image k starts at k * nodes_per_image.
            sources, targets = self._grid_edges(k * nodes_per_image)
            edges_index[0].extend(sources)
            edges_index[1].extend(targets)
            edges_attr.extend([1, 1] for _ in sources)
            batch.extend([k] * nodes_per_image)

            labels.append(label)
            u.append(patch_means)

        return (tf.cast(self.List2Tensor(nodes), dtype=tf.float32),
                tf.cast(self.List2Tensor(edges_index), dtype=tf.int32),
                tf.cast(self.List2Tensor(edges_attr), dtype=tf.float32),
                tf.cast(self.List2Tensor(u), dtype=tf.float32),
                tf.cast(self.List2Tensor(batch), dtype=tf.int32),
                tf.one_hot(labels, len(self.alpha)))

# Disabled smoke test, kept as an inert string literal that is never executed.
# It loads batch 7 (batch_size=3) from a local dataset path and prints the
# shape of every tensor returned by Code.next_batch, plus the labels.
# To run it, move the code under an ``if __name__ == "__main__":`` guard.
"""

batch_size = 3
img_height = 100
img_width = 56
stride = 8

code = Code("../../pytorch_verification_code/dataset/train",batch_size,stride,img_width,img_height)

nodes,edges_index,edges_attr,u,batch,labels = code.next_batch(7)


print(np.array(nodes).shape)
print(np.array(edges_index).shape)
print(np.array(edges_attr).shape)
print(np.array(u).shape)
print(np.array(batch).shape)
print(np.array(labels).shape)
print(labels)
"""

GitHub: https://github.com/coolsunxu/sonnet_graph_nets

 类似资料:

相关阅读

相关文章

相关问答