[백준] 21924: 도시 건설 - JAVA

https://www.acmicpc.net/problem/21924

풀이

도로를 연결해야 한다. 최소한의 비용으로.

즉, 최소 스패닝 트리(MST)에 대한 문제이다.

크루스칼과 프림의 두 가지 알고리즘으로 풀 수 있는데 크루스칼 알고리즘을 선택했다.

크루스칼 알고리즘은 간선을 하나씩 선택해서 최소신장트리를 찾는 알고리즘이다.

1. 처음에 모든 간선을 가중치에 따라 오름차순으로 정렬하고 시작한다.

2. 가중치가 가장 낮은 간선부터 선택하면서 트리를 증가시킨다.(이때, 사이클이 존재하게 될 경우 다음으로 가중치가 낮은 간선을 선택한다.)

3. N-1개의 간선이 선택될때까지 2번 과정을 반복한다.

간선을 비용에 따라 정렬해야 하는데 정렬하는 방법으로 PriorityQueue를 이용했다.

크루스칼 알고리즘을 수행하는 메서드를 하나 따로 만들어서

N-1개가 선택되면 값을 반환해줬고

N-1개가 선택되지 않을 경우 모든 건물이 연결이 되지 않았다는 뜻으로 -1을 반환해줬다.


메모리: 162860KB

시간: 1152ms

언어: Java 11

코드

import java.io.*;
import java.util.*;

public class Main {
    static class Node implements Comparable<Node> {
        int a;
        int b;
        int cost;

        public Node(int a, int b, int cost) {
            this.a = a;
            this.b = b;
            this.cost = cost;
        }

        @Override
        public int compareTo(Node o) {
            return this.cost - o.cost;
        }
    }

    static int N, M;
    static int[] parent;
    static long sum;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        N = Integer.parseInt(st.nextToken());
        M = Integer.parseInt(st.nextToken());

        sum = 0;
        PriorityQueue<Node> pq = new PriorityQueue<>();
        for (int i = 0; i < M; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            int c = Integer.parseInt(st.nextToken());

            sum += c;
            pq.add(new Node(a, b, c));
        }

        parent = new int[N + 1];
        for (int i = 0; i < N + 1; i++) {
            parent[i] = i;
        }

        System.out.println(doKruskal(pq));
    }

    private static long doKruskal(PriorityQueue<Node> pq) {
        long total = 0;
        int cnt = 0;

        while (!pq.isEmpty()) {
            Node node = pq.poll();

            int px = findSet(node.a);
            int py = findSet(node.b);

            if (px != py) {
                total += node.cost;
                union(node.a, node.b);
                cnt++;
            }

            if (cnt == N - 1) {
                return sum - total;
            }
        }

        return -1;
    }

    private static void union(int a, int b) {
        int px = findSet(a);
        int py = findSet(b);

        if (px > py) {
            parent[py] = px;
        } else {
            parent[px] = py;
        }
    }

    private static int findSet(int x) {
        if (x == parent[x]) {
            return x;
        } else {
            return parent[x] = findSet(parent[x]);
        }
    }

}