使用图注意力网络求解网格上的 TSP 问题 wu-kan

本文使用图注意力网络求解网格上的 TSP 问题,TSP 问题的规模为 10 个城市。实验中构造了 8000 张图作为训练集,2000 张图作为测试集,并最终在测试集上得到了 80.6% 的 Top-2 准确率。

导言

问题背景

旅行商问题,即 TSP 问题(Traveling Salesman Problem)又译为旅行推销员问题、货郎担问题,是数学领域中著名问题之一。假设有一个旅行商人要拜访 n 个城市,他必须选择所要走的路径,路径的限制是每个城市只能拜访一次,而且最后要回到原来出发的城市。路径的选择目标是要求得的路径路程为所有路径之中的最小值。等价于求图的最短哈密尔顿回路问题。

求解思路

图注意网络层

关于图注意力网络层的相关原理,详见 这一篇博客

建立模型

设计的网络如下。输入 $n \times 2$ 的坐标向量,输出 $n\times n$ 的 0-1 矩阵,对应图中每条边的分类情况。

图注意网络层中注意力是通过顶点上的属性求出来的,这里使用的是标准化后的顶点坐标。网络使用了两个图注意网络层(GAT)。第一层使用 10 head 的 GAT,后接 elu 作为非线性单元;第二层为分类层,后接一个 softmax。

flowchart TB
Input--n*2-->Normalization
Normalization--n*2-->GAT1.1
Normalization--n*2-->GAT1.2
Normalization--n*2-->GAT1.m
subgraph GAT1
GAT1.1
GAT1.2
GAT1.m
end
GAT1.1--n*2-->elu
GAT1.2--n*2-->elu
GAT1.m--n*2-->elu
elu--m*n*2-->GAT2
GAT2--n*n-->softmax
softmax--n*n-->output

实验过程

实验环境

所用机器型号为 VAIO Z Flip 2016。

  • Intel(R) Core(TM) i7-6567U CPU @3.30GHZ 3.31GHz
  • 8.00GB RAM
  • Python 3.8.2 64-bit
    • jupyter==1.0.0
    • numpy==1.18.4
    • torch==1.5.0+cpu
    • torch-scatter==2.0.4
    • torch-sparse==0.6.4
    • torch-cluster==1.5.4
    • torch-geometric==1.5.0
    • jupyter==1.0.0
    • numpy==1.18.4
    • matplotlib==3.2.1

导入相关包和并数据集

此处构造了 10000 个大小为 10 个顶点的无向图,并使用状态压缩动态规划的方法求解出对应的最优解。其中每组数据的顶点坐标相同而边的连接方式不同,因而不同数据有不同的解。同时给出第一个数据的可视化,容易看出构造的数据是正确的。数据集被划分成大小为 8000 的训练集和大小为 2000 的测试集。

由于要保存路径,这里构造单组数据的时间复杂度是 $O(2^nn^3)$,大约每秒可以生成 10~20 组数据。

import io
import math
import random
import copy
from matplotlib import pyplot
import torch
from torch import nn
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data

def creat_dataset(num_data=10000,coord_range=100,num_nodes=10):
        data_list = []
        random.seed(0)
        coords = random.sample([[i,j]for i in range(coord_range)for j in range(coord_range)], k=num_nodes)
        flag = 1
        for _i in range(num_data):
                permutation = random.sample(range(num_nodes),k=num_nodes)
                edge_index = [[permutation[i-1],permutation[i]] for i in range(len(permutation))]
                for i in range(num_nodes):
                    for j in range(i):
                        if random.random()<0.5 and [i,j]not in edge_index:
                            edge_index.append([i,j])

                adj=[[math.sqrt(math.pow(coords[i][0]-coords[j][0],2)+math.pow(coords[i][1]-coords[j][1],2))
                                     if [i,j] in edge_index or [j,i] in edge_index  else 2.0*coord_range*num_nodes
                                 for j in range(num_nodes)]for i in range(num_nodes)]

                dp = [[[adj[0][0],[0]]for p in range(len(adj))]for s in range(1<<len(adj))]
                dp[1][0] = [0,[0]]
                for s in range(1<<len(adj)):
                        for v in range(len(adj)):
                            if ((s>>v)&1):
                                for u in range(len(adj)):
                                    if ((s>>u)&1) and dp[s][v][0]>dp[s^(1<<v)][u][0]+adj[u][v]:
                                        dp[s][v]=copy.deepcopy(dp[s^(1<<v)][u])
                                        dp[s][v][0]+=adj[u][v]
                                        dp[s][v][1].append(v)
                ans = [adj[0][0],[0]]
                for i in range(len(adj)):
                    if ans[0] > dp[-1][i][0]+adj[i][0]:
                        ans = [dp[-1][i][0]+adj[i][0],dp[-1][i][1]]

                ans_tour = [[ans[1][i-1],ans[1][i]] for i in range(len(ans[1]))]

                edge_info=[]
                for i in range(len(adj)):
                    for j in range(i):
                        if adj[i][j]!=adj[0][0]:
                            attr=1 if [i,j]in ans_tour or [j,i] in ans_tour else 0
                            edge_info.append([i,j,attr])
                            edge_info.append([j,i,attr])
                if flag:
                    flag=0
                    x=[]
                    y=[]
                    for ed in ans_tour:
                        x.append(coords[ed[0]][0])
                        y.append(coords[ed[0]][1])
                    x.append(x[0])
                    y.append(y[0])
                    pyplot.plot(x,y)
                    pyplot.show()
                    print(ans_tour)
                    print(edge_info)
                pos=coords

                y=[[0 for j in range(len(pos))]for i in range(len(pos))]
                edge_index=[]
                edge_attr=[]
                num_edges=len(edge_info)
                for i in range(num_edges):
                    edge_index.append([edge_info[i][0],edge_info[i][1]])
                    edge_attr.append(edge_info[i][2])
                    if edge_attr[-1]:
                        y[edge_index[-1][0]][edge_index[-1][1]]=1

                x=torch.tensor(pos, dtype=torch.float)
                y=torch.tensor(y, dtype=torch.float)
                edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
                edge_attr = torch.tensor(edge_attr, dtype=torch.float)
                pos=torch.tensor(pos, dtype=torch.float)
                data_list.append(Data(x=x, y=y, edge_index=edge_index,edge_attr=edge_attr,pos=pos))
                print(_i,end='\r')
        return data_list

torch.manual_seed(2020) # seed for reproducible numbers
dataset = creat_dataset()
train_dataset = dataset[0:8000]
test_dataset = dataset[8000:10000]

print(f"Number of Train Graphs :", len(train_dataset))
print(f"Number of Test Graphs :", len(test_dataset))
print(f"Number of Nodes per Graph:", train_dataset[0].num_nodes)
print(f"Number of Node Features:", train_dataset[0].num_node_features)

可视化1

[[3, 0], [0, 6], [6, 8], [8, 4], [4, 5], [5, 9], [9, 1], [1, 7], [7, 2], [2, 3]]
[[1, 0, 0], [0, 1, 0], [2, 0, 0], [0, 2, 0], [3, 0, 1], [0, 3, 1], [3, 2, 1], [2, 3, 1], [4, 0, 0], [0, 4, 0], [4, 2, 0], [2, 4, 0], [5, 0, 0], [0, 5, 0], [5, 2, 0], [2, 5, 0], [5, 4, 1], [4, 5, 1], [6, 0, 1], [0, 6, 1], [7, 1, 1], [1, 7, 1], [7, 2, 1], [2, 7, 1], [7, 4, 0], [4, 7, 0], [7, 6, 0], [6, 7, 0], [8, 1, 0], [1, 8, 0], [8, 3, 0], [3, 8, 0], [8, 4, 1], [4, 8, 1], [8, 6, 1], [6, 8, 1], [8, 7, 0], [7, 8, 0], [9, 0, 0], [0, 9, 0], [9, 1, 1], [1, 9, 1], [9, 2, 0], [2, 9, 0], [9, 3, 0], [3, 9, 0], [9, 5, 1], [5, 9, 1], [9, 8, 0], [8, 9, 0]]
Number of Train Graphs : 8000
Number of Test Graphs : 2000
Number of Nodes per Graph: 10
Number of Node Features: 2

模型

import torch
from torch import nn
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.hid = 10
        self.in_head = 10
        self.out_head = 1

        self.norm1 = nn.BatchNorm1d(dataset[0].num_node_features)
        self.conv1 = GATConv(dataset[0].num_node_features, self.hid, heads=self.in_head)
        self.conv2 = GATConv(self.hid*self.in_head, dataset[0].num_nodes, concat=False,
                             heads=self.out_head)


    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # Dropout before the GAT layer is used to avoid overfitting in small datasets like Cora.
        # One can skip them if the dataset is sufficiently large.

        x = self.norm1(x)
        #x = nn.functional.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = nn.functional.elu(x)
        #x = nn.functional.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)

        return nn.functional.softmax(x, dim=1)

训练模型并评估模型效果

训练了 100 个 epoch,训练时的 batch_size = 100,使用二进制交叉熵作为训练时的 Loss 函数。

单次 epoch 之后使用测试集计算当前的 Top-2 准确度,可以看到 80 个 epoch 之后模型的准确率稳定在 80% 上下。

from torch_geometric.data import DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GAT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
acc_list=[]
model.train()
for epoch in range(100):
    for data in DataLoader(train_dataset,batch_size=100,shuffle=True):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = nn.functional.binary_cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
    if epoch%1 == 0:
        model.eval()
        tot=0
        acc=0
        for data in DataLoader(test_dataset):
            out = model(data)
            out = out.detach().numpy().tolist()
            for i in range(len(out)):
                for j in range(len(out[i])):
                    out[i][j]=[out[i][j],j]
                out[i].sort(reverse=True)
                while(len(out[i])>2):
                    out[i].pop()
            y=data.y.numpy().tolist()
            for i in range(len(out)):
                tot+=1
                if y[i][out[i][0][1]] or y[i][out[i][1][1]]:
                    acc+=1
        acc_list.append(acc/tot)
        print(str(epoch)+'\t'+str(acc_list[-1]),end='\r')
    if epoch % 10 == 9:
        torch.save(model, './model.'+str(epoch)+'.pkl')
print(max(acc_list))
pyplot.plot(acc_list)
pyplot.show()
0.80629665

可视化2

总结

直接对边进行分类的方法比较难直接得到 TSP 问题的一个可行解(所得到的边未必能够构成一条路径),但是却反映了每条边作为优解一部分的概率。因此,通过图注意网络进行分类,可以通过已有的数据学习提取出图的一些特征,随后用于后续搜索算法的启发式函数,从而进一步优化已有搜索算法的表现。