Algorithm
백준
Python
최소 스패닝 트리

백준 1197 - 최소 신장 트리

https://www.acmicpc.net/problem/1197 (opens in a new tab)

내 풀이

import sys
import heapq
 
v,e = map(int, sys.stdin.readline().split())
times = [[] for _ in range(v + 1)]
all_v = []
 
for _ in range(e) :
  start, end, cost = map(int, sys.stdin.readline().split())
  times[start].append([cost, end])
  if start is not all_v :
    all_v.append(start)
 
def dijkstra(graph, start, all_v) :
  costs = {}
  pq = []
  heapq.heappush(pq, [0, start])
  while pq :
    cur_cost, cur_v = heapq.heappop(pq)
    if cur_v not in costs :
      costs[cur_v] = cur_cost
      for cost, next_v in graph[cur_v] :
        next_cost = cur_cost + cost
        heapq.heappush(pq, [next_cost, next_v])
 
  cost_keys = costs.keys()
  for v in all_v :
    if v not in cost_keys :
      return -1
 
  return max(cost_keys)
 
min_value = 1000000
for start_v in all_v :
  result = dijkstra(times, start_v, all_v)
  if result == -1 :
    continue
  else :
    min_value = min(result, min_value)
 
print(min_value)

Java 풀이

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.PriorityQueue;
import java.util.StringTokenizer;
 
/**
 * 참고 : https://velog.io/@jodawooooon/Java-BOJ-1197-%EC%B5%9C%EC%86%8C-%EC%8A%A4%ED%8C%A8%EB%8B%9D-%ED%8A%B8%EB%A6%AC-MST
 * 풀지 못한 이유 : Comparable을 Implements 해서 Node Class를 직접 만들어야 한다는걸 몰랐음 (Python 처럼 List 넣으면 0번 index 기준으로 정렬될줄
 */
public class Main {
	static class Node implements Comparable<Node> {
		int to;
		int cost;
 
		public Node(int to, int cost) {
			super();
			this.to = to;
			this.cost = cost;
		}
 
    /**
       	•	this.cost - o.cost가 음수일 때: this 객체가 o 객체보다 작다고 판단되어 this 객체가 먼저 오게 됩니다. 즉, 오름차순으로 정렬됩니다.
	      •	this.cost - o.cost가 양수일 때: this 객체가 o 객체보다 크다고 판단되어 this 객체가 나중에 오게 됩니다.
	      •	this.cost - o.cost가 0일 때: 두 객체의 순서는 동일하다고 판단됩니다.
     */
		@Override
		public int compareTo(Node o) {
			return this.cost - o.cost;
		}
	}
	static int V, E, ans;
	static boolean[] visited;
	static ArrayList<Node>[] nodeList;
	public static void main(String[] args) throws Exception {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
 
		V = Integer.parseInt(st.nextToken());
		E = Integer.parseInt(st.nextToken());
 
		visited = new boolean[V+1];
		nodeList = new ArrayList[V+1];
		for (int i = 1; i <= V; i++) {
			nodeList[i] = new ArrayList<>();
		}
 
		for (int i = 0; i < E; i++) {
			st = new StringTokenizer(br.readLine());
 
			int from = Integer.parseInt(st.nextToken());
			int to = Integer.parseInt(st.nextToken());
			int cost = Integer.parseInt(st.nextToken());
 
			//from번 정점과 to번 정점이 가중치 cost인 간선으로 연결되어 있다
			nodeList[from].add(new Node(to, cost));
			nodeList[to].add(new Node(from, cost));
		}
 
		PriorityQueue<Node> pq = new PriorityQueue<>();
		pq.add(new Node(1,0));
 
		while(!pq.isEmpty()){
			Node n = pq.poll();
			int to = n.to;
			int cost = n.cost;
 
			if(visited[to]) continue;
			visited[to] = true;
			ans += cost;
 
			for(Node next : nodeList[to]) {
				if(!visited[next.to]) pq.add(next);
			}
		}
		System.out.println(ans);
	}
}

풀이

그냥 크루스칼 알고리즘이나 프림 알고리즘을 사용하면 풀리는 기본적인 문제다.

import heapq
import sys
 
sys.setrecursionlimit(10 ** 6)
input = sys.stdin.readline
 
n, m = map(int, input().split())  # 노드 수, 간선 수
graph = [[] for _ in range(n + 1)]
visited = [0] * (n + 1)  # 노드의 방문 정보 초기화
 
# 무방향 그래프 생성
for i in range(m):  # 간성 정보 입력 받기
  start, end, weight = map(int, input().split())
  graph[start].append([weight, start, end])
  graph[end].append([weight, end, start])
 
# 프림 알고리즘
def prim(graph, start_node):
  visited[start_node] = 1  # 방문 갱신
  pq = graph[start_node]  # 인접 간선 추출
  heapq.heapify(pq)  # 우선순위 큐 생성
  total_cost = 0  # 전체 가중치
 
  while pq:
    weight, start, end = heapq.heappop(pq)  # 가중치가 가장 적은 간선 추출
    if visited[end] == 0:  # end가 방문하지 않았다면
      visited[end] = 1  # end 방문 갱신
      total_cost += weight  # 전체 가중치 갱신
 
      for edge in graph[end]:  # end가 start로 감
        if visited[edge[2]] == 0:  # edge[2]는 다음번 방문할 end
          heapq.heappush(pq, edge)
 
  return total_cost
 
print(prim(graph, 1))