1584. Min Cost to Connect All Points
背景知识
这道题本质上是一道最小生成树 Minimum Spanning Tree (MST) 的题目。
题目给了我们在二维平面上的 n 个点,让我们求能让总距离最短且能将所有点连接起来的路径。
这个过程,恰巧也是生成 MST 的过程。
求 MST,主要有两个算法:Kruskal 与 Prim,下面将会分别描述两个算法的思路
解题思路——Kruskal's Algorithm
Kruskal 算法把焦点放在 “边” 上,他的想法是:
- 先求出所有的边,再对它们的权重(一般是距离)排序
- 根据排序,从小到大依次取边
- 如果取得的边会让图形构成 “回路”,就抛弃这条边,继续第二步
- 如果不会则持续第二步,直到所有边都被取完或被丢弃
里面值得推敲的是,要如何判断取得的边会构成 “回路”,这一点就要用到并查集 Union-find, aka Disjoint-set。
简单来说,并查集是一个数据解构,它提供了两种方法:
- 合并 Union:将一个节点合并到一个组中
- 查 Find:查找一个节点是否在一个组中
我们可以利用它的特性,经过以下步骤确认是否成环:
- 把第一个点加到组里
- 取得选中的边的终点,并且确认是否已经在组里了
- 如果已经在组里了,那么表示加入这个边必定成环,则丢弃
- 如果不在组里,表示选中的边不会成环,把它的终点加到组里,然后重复第二步直到所有的边都取完了
C++
class UnionFind {
private:
vector<int> group;
public:
UnionFind(int size) {
group = vector<int>(size);
for (int i = 0; i < size; i++) {
group[i] = i;
}
}
int find(int node) {
if (group[node] != node) {
group[node] = find(group[node]);
}
return group[node];
}
void merge(int node1, int node2) {
int group1 = find(node1);
int group2 = find(node2);
// already in 1 group
if (group1 == group2) {
return;
}
group[group1] = group2;
}
bool connected(int node1, int node2) { return find(node1) == find(node2); }
};
class Solution {
public:
int minCostConnectPoints(vector<vector<int>>& points) {
int n = points.size();
vector<pair<int, pair<int, int>>> edges;
// store edges' distance
for (int curNode = 0; curNode < n; curNode++) {
for (int nextNode = curNode + 1; nextNode < n; nextNode++) {
int distance = abs(points[curNode][0] - points[nextNode][0]) +
abs(points[curNode][1] - points[nextNode][1]);
edges.push_back({distance, {curNode, nextNode}});
}
}
// sort by distance
sort(edges.begin(), edges.end());
UnionFind uf(n);
int result = 0;
for (auto& edge : edges) {
int node1 = edge.second.first;
int node2 = edge.second.second;
if (uf.connected(node1, node2)) {
continue;
}
result += edge.first;
uf.merge(node1, node2);
}
return result;
}
};
解题思路——Prim's Algorithm
相比 Kruskal 把重心放在边上,Prim 算法把重心放在 “点” 上。
它的核心思路在于:
- 一开始选中任意一点,将其加入 MST 中
- 选中距离 MST 最近的一个点,将其也加入 MST 中
- 重复第二步直到所有的点都被纳入 MST 中
使用一般的 Prim 算法的性能瓶颈在于:要维护一个优先队列,用以保存所有与 MST 相邻的边长度。
实现如下:
/*
* @lc app=leetcode id=1584 lang=cpp
*
* [1584] Min Cost to Connect All Points
*/
struct Edge {
int start;
int target;
int distance;
Edge(int _start, int _target, int _distance) {
start = _start;
target = _target;
distance = _distance;
}
};
struct Cmp {
bool operator()(Edge &a, Edge &b) { return a.distance > b.distance; }
};
// @lc code=start
class Solution {
private:
// visited points
vector<int> visited;
// reserve shortest edge at top
priority_queue<Edge, vector<Edge>, Cmp> pq;
// find all edges of the point, and push into pq
void findEdge(int point, vector<vector<pair<int, int>>> &graph) {
for (auto &edge : graph[point]) {
int target = edge.first;
// if visited, forget it
if (visited[target]) {
continue;
}
int distance = edge.second;
pq.push({point, target, distance});
}
}
public:
int minCostConnectPoints(vector<vector<int>> &points) {
int n = points.size();
visited.resize(n);
vector<vector<pair<int, int>>> graph(n);
// reserve edges' distance
for (int curNode = 0; curNode < n; curNode++) {
for (int nextNode = curNode + 1; nextNode < n; nextNode++) {
int distance = abs(points[curNode][0] - points[nextNode][0]) +
abs(points[curNode][1] - points[nextNode][1]);
graph[curNode].push_back({nextNode, distance});
graph[nextNode].push_back({curNode, distance});
}
}
int result = 0;
visited[0] = 1;
findEdge(0, graph);
while (!pq.empty()) {
Edge edge = pq.top();
pq.pop();
int target = edge.target;
if(visited[target]) {
continue;
}
result += edge.distance;
visited[target] = 1;
findEdge(target, graph);
}
return result;
}
};
// @lc code=end
所以,我们可以换一种思路,优化一下优先队列。
具体地说,我们可以改为维护一个数组 minDist
,它是一个长度等于节点数的数组,minDist[n]
表示从节点 n 到 MST 树的最短距离。
具体步骤如下:
- 在一开始,我们先把
minDist
的每一项都设为∞ - 然后我们任意选择一点
- 将选择的点加入 MST
- 然后将
minDist[n]
设为 0(n 为选择的点),将其他元素更新成该点到点 n 的距离 - 从更新好的距离中选择还没被加入 MST 且距离最短的点,重复步骤三
实现如下:
class Solution {
public:
int minCostConnectPoints(vector<vector<int>> &points) {
int n = points.size();
int result = 0;
int edgeUsed = 0;
vector<bool> visited(n);
// minDist[n]: the shortest distance from n to MST
vector<int> minDist(n, INT_MAX);
// start from points[0]
minDist[0] = 0;
while (edgeUsed < n) {
int curMinEdge = INT_MAX;
int curNode = -1;
for (int i = 0; i < n; i++) {
if (!visited[i] && curMinEdge > minDist[i]) {
curMinEdge = minDist[i];
curNode = i;
}
}
visited[curNode] = true;
result += curMinEdge;
edgeUsed++;
// update minDist for adjacent node
for (int nextNode = 0; nextNode < n; nextNode++) {
int distance = abs(points[curNode][0] - points[nextNode][0]) +
abs(points[curNode][1] - points[nextNode][1]);
if (!visited[nextNode] && distance < minDist[nextNode]) {
minDist[nextNode] = distance;
}
}
}
return result;
}
};
如果还有看不懂的同学可以看一下 官方解答 的第三个的动图就懂了