1584. Min Cost to Connect All Points

Leetcode link

背景知识

这道题本质上是一道最小生成树 Minimum Spanning Tree (MST) 的题目。

题目给了我们在二维平面上的 n 个点,让我们求能让总距离最短且能将所有点连接起来的路径。

这个过程,恰巧也是生成 MST 的过程。

求 MST,主要有两个算法:Kruskal 与 Prim,下面将会分别描述两个算法的思路

解题思路——Kruskal's Algorithm

Kruskal 算法把焦点放在 “边” 上,他的想法是:

  1. 先求出所有的边,再对它们的权重(一般是距离)排序
  2. 根据排序,从小到大依次取边
  3. 如果取得的边会让图形构成 “回路”,就抛弃这条边,继续第二步
  4. 如果不会则持续第二步,直到所有边都被取完或被丢弃


里面值得推敲的是,要如何判断取得的边会构成 “回路”,这一点就要用到并查集 Union-find, aka Disjoint-set

简单来说,并查集是一个数据解构,它提供了两种方法:

  • 合并 Union:将一个节点合并到一个组中
  • 查 Find:查找一个节点是否在一个组中

我们可以利用它的特性,经过以下步骤确认是否成环:

  1. 把第一个点加到组里
  2. 取得选中的边的终点,并且确认是否已经在组里了
  3. 如果已经在组里了,那么表示加入这个边必定成环,则丢弃
  4. 如果不在组里,表示选中的边不会成环,把它的终点加到组里,然后重复第二步直到所有的边都取完了

C++

class UnionFind {
 private:
  vector<int> group;

 public:
  UnionFind(int size) {
    group = vector<int>(size);
    for (int i = 0; i < size; i++) {
      group[i] = i;
    }
  }

  int find(int node) {
    if (group[node] != node) {
      group[node] = find(group[node]);
    }
    return group[node];
  }

  void merge(int node1, int node2) {
    int group1 = find(node1);
    int group2 = find(node2);

    // already in 1 group
    if (group1 == group2) {
      return;
    }
    group[group1] = group2;
  }

  bool connected(int node1, int node2) { return find(node1) == find(node2); }
};

class Solution {
 public:
  int minCostConnectPoints(vector<vector<int>>& points) {
    int n = points.size();
    vector<pair<int, pair<int, int>>> edges;

    // store edges' distance
    for (int curNode = 0; curNode < n; curNode++) {
      for (int nextNode = curNode + 1; nextNode < n; nextNode++) {
        int distance = abs(points[curNode][0] - points[nextNode][0]) +
                       abs(points[curNode][1] - points[nextNode][1]);
        edges.push_back({distance, {curNode, nextNode}});
      }
    }

    // sort by distance
    sort(edges.begin(), edges.end());

    UnionFind uf(n);
    int result = 0;
    for (auto& edge : edges) {
      int node1 = edge.second.first;
      int node2 = edge.second.second;
      if (uf.connected(node1, node2)) {
        continue;
      }
      result += edge.first;
      uf.merge(node1, node2);
    }
    return result;
  }
};

解题思路——Prim's Algorithm

相比 Kruskal 把重心放在边上,Prim 算法把重心放在 “点” 上。

它的核心思路在于:

  1. 一开始选中任意一点,将其加入 MST 中
  2. 选中距离 MST 最近的一个点,将其也加入 MST 中
  3. 重复第二步直到所有的点都被纳入 MST 中


使用一般的 Prim 算法的性能瓶颈在于:要维护一个优先队列,用以保存所有与 MST 相邻的边长度。

实现如下:

/*
 * @lc app=leetcode id=1584 lang=cpp
 *
 * [1584] Min Cost to Connect All Points
 */
struct Edge {
  int start;
  int target;
  int distance;
  Edge(int _start, int _target, int _distance) {
    start = _start;
    target = _target;
    distance = _distance;
  }
};

struct Cmp {
  bool operator()(Edge &a, Edge &b) { return a.distance > b.distance; }
};

// @lc code=start
class Solution {
 private:
  // visited points
  vector<int> visited;
  // reserve shortest edge at top
  priority_queue<Edge, vector<Edge>, Cmp> pq;
  // find all edges of the point, and push into pq
  void findEdge(int point, vector<vector<pair<int, int>>> &graph) {
    for (auto &edge : graph[point]) {
      int target = edge.first;
      // if visited, forget it
      if (visited[target]) {
        continue;
      }
      int distance = edge.second;
      pq.push({point, target, distance});
    }
  }

 public:
  int minCostConnectPoints(vector<vector<int>> &points) {
    int n = points.size();
    visited.resize(n);
    vector<vector<pair<int, int>>> graph(n);

    // reserve edges' distance
    for (int curNode = 0; curNode < n; curNode++) {
      for (int nextNode = curNode + 1; nextNode < n; nextNode++) {
        int distance = abs(points[curNode][0] - points[nextNode][0]) +
                       abs(points[curNode][1] - points[nextNode][1]);
        graph[curNode].push_back({nextNode, distance});
        graph[nextNode].push_back({curNode, distance});
      }
    }

    int result = 0;
    visited[0] = 1;
    findEdge(0, graph);
    while (!pq.empty()) {
      Edge edge = pq.top();
      pq.pop();
      int target = edge.target;
      if(visited[target]) {
        continue;
      }
      result += edge.distance;
      visited[target] = 1;
      findEdge(target, graph);
    }
    return result;
  }
};
// @lc code=end


所以,我们可以换一种思路,优化一下优先队列。

具体地说,我们可以改为维护一个数组 minDist,它是一个长度等于节点数的数组,minDist[n] 表示从节点 n 到 MST 树的最短距离。

具体步骤如下:

  1. 在一开始,我们先把 minDist 的每一项都设为∞
  2. 然后我们任意选择一点
  3. 将选择的点加入 MST
  4. 然后将 minDist[n] 设为 0(n 为选择的点),将其他元素更新成该点到点 n 的距离
  5. 从更新好的距离中选择还没被加入 MST 且距离最短的点,重复步骤三

实现如下:

class Solution {
 public:
  int minCostConnectPoints(vector<vector<int>> &points) {
    int n = points.size();
    int result = 0;
    int edgeUsed = 0;

    vector<bool> visited(n);
    // minDist[n]: the shortest distance from n to MST
    vector<int> minDist(n, INT_MAX);
    // start from points[0]
    minDist[0] = 0;

    while (edgeUsed < n) {
      int curMinEdge = INT_MAX;
      int curNode = -1;

      for (int i = 0; i < n; i++) {
        if (!visited[i] && curMinEdge > minDist[i]) {
          curMinEdge = minDist[i];
          curNode = i;
        }
      }

      visited[curNode] = true;
      result += curMinEdge;
      edgeUsed++;

      // update minDist for adjacent node
      for (int nextNode = 0; nextNode < n; nextNode++) {
        int distance = abs(points[curNode][0] - points[nextNode][0]) +
                       abs(points[curNode][1] - points[nextNode][1]);
        if (!visited[nextNode] && distance < minDist[nextNode]) {
          minDist[nextNode] = distance;
        }
      }
    }

    return result;
  }
};

如果还有看不懂的同学可以看一下 官方解答 的第三个的动图就懂了

results matching ""

    No results matching ""