지우너

[코드트리] 최소 스패닝 트리7 C++ 본문

Problem Solving

[코드트리] 최소 스패닝 트리7 C++

지옹 2024. 11. 7. 17:20

문제

https://www.codetree.ai/missions/9/problems/minimum-spanning-tree-7?&utm_source=clipboard&utm_medium=text

 

코드트리 | 코딩테스트 준비를 위한 알고리즘 정석

국가대표가 만든 코딩 공부의 가이드북 코딩 왕초보부터 꿈의 직장 코테 합격까지, 국가대표가 엄선한 커리큘럼으로 준비해보세요.

www.codetree.ai

 

코드

#include <iostream>
#include <vector>
#include <queue>

using namespace std;

const int MAX_N = 40'000;

int n, m, mst_sum, max_route, far_node;
vector<pair<int,int> > edges[MAX_N+1]; // edges[a]={b, w}
vector<pair<int,int> > mst_edges[MAX_N+1];
int dist[MAX_N+1];
bool visited[MAX_N+1]={false, };

void initDist(){
    for(int i=0; i<=n; ++i){
        dist[i]=1e9;
    }
}

// 1번 노드에서 시작해서 가장 먼 노드를 찾기
// 가장 먼 노드에서 가장 먼 노드 구하기(트리의 지름)
void dfs(int curr_node, int dist){
    if(visited[curr_node]) return;
    visited[curr_node]=true;
    
    if(dist>max_route){
        max_route=dist;
        far_node=curr_node;
    }
    
    // mst_edges를 사용하도록 수정
    for(auto e : mst_edges[curr_node]){
        int v=e.first, w=e.second;
        dfs(v, dist+w);
    }
}

void prim(){
    priority_queue<pair<int,int>, vector<pair<int,int>>, greater<>> pq; //{weight, node}
    vector<int> parent(n+1, -1);  // 각 노드의 부모 노드를 저장
    
    // 1번 노드에서 시작
    dist[1]=0;
    pq.push({0, 1});
    
    while(!pq.empty()){
        int min_weight = pq.top().first, min_vertex = pq.top().second;
        pq.pop();
        
        if(visited[min_vertex]) continue;
        visited[min_vertex]=true;
        mst_sum+=min_weight;
        
        // parent가 있다면 (시작노드가 아니라면) mst_edges에 추가
        int parent_of_min_vertex = parent[min_vertex];
        if(parent_of_min_vertex != -1) {
            mst_edges[parent_of_min_vertex].push_back({min_vertex, min_weight});
            mst_edges[min_vertex].push_back({parent_of_min_vertex, min_weight});
        }
        
        // min_vertex에 이어진 간선 탐색
        for(auto e: edges[min_vertex]){
            int v=e.first, w=e.second;
            if(!visited[v] && dist[v]>w){
                dist[v]=w;
                parent[v] = min_vertex;  // v의 부모를 min_vertex로 설정
                pq.push({w, v});
            }
        }
    }
}

int main() {
    // [input] 정점의 개수 n, 간선의 개수 m
    cin >> n >> m;
    initDist();
    // [input] m개의 줄에 걸쳐, 각 간선의 양 끝 점과 가중치
    for(int i=0; i<m; ++i){
        int a, b, w;
        cin >> a >> b >> w;
        edges[a].push_back({b, w});
        edges[b].push_back({a, w});
    }

    // [solution] mst 구하기
    fill(visited, visited+(n+1), false);
    prim();

    // [solution] mst의 지름 구하기
    fill(visited, visited+(n+1), false);
    max_route=0;
    dfs(1, 0);
    fill(visited, visited+(n+1), false);
    max_route=0;
    dfs(far_node, 0);

    // [output]
    cout << mst_sum << '\n';
    cout << max_route << '\n';
    return 0;
}

 

[틀렸던 코드]

1. MST를 구하고 -> mst_sum을 구한다

2. MST의 지름을 구하는 문제 -> max_route

(선택된 간선 중 하나를 시작해서 거기서 가장 먼 노드를 찾은 후, 그 노드에서 가장 먼 노드까지의 거리를 구하면 됨)

아래의 코드는 mst가 아닌 트리 전체의 지름을 구하기 때문에 틀렸다. 

mst의 간선을 따로 저장해서 해당 vector로 dfs를 진행하면 맞을 거 같았음.

#include <iostream>
#include <vector>
#include <queue>

using namespace std;

const int MAX_N = 40'000;
int n, m, mst_sum, max_route, far_node;
vector<pair<int,int> > edges[MAX_N+1]; // edges[a]={b, w}
int dist[MAX_N+1];
bool visited[MAX_N+1]={false, };

void initDist(){
    for(int i=0; i<=n; ++i){
        dist[i]=1e9;
    }
}

// 1번 노드에서 시작해서 가장 먼 노드를 찾기
// 가장 먼 노드에서 가장 먼 노드 구하기(트리의 지름)
void dfs(int curr_node, int dist){
    if(visited[curr_node]) return;
    visited[curr_node]=true;

    if(dist>max_route){
        max_route=dist;
        far_node=curr_node;
    }

    // 현재 노드에 연결된 노드 탐색
    for(auto e : edges[curr_node]){
        int v=e.first, w=e.second;
        dfs(v, dist+w);
    }
}

void prim(){
    priority_queue<pair<int,int>, vector<pair<int,int> >, greater<>> pq;//{weight, node}
    
    // 1번 노드에서 시작
    dist[1]=0;
    pq.push({0, 1});

    while(!pq.empty()){
        int min_weight = pq.top().first, min_vertex = pq.top().second;
        pq.pop();

        if(visited[min_vertex]) continue;
        visited[min_vertex]=true;
        mst_sum+=min_weight;

        // min_vertex에 이어진 간선 탐색
        for(auto e: edges[min_vertex]){
            int v=e.first, w=e.second;
            if(!visited[v] && dist[v]>w){
                dist[v]=w;
                pq.push({w, v});
            }
        }
    }

}

int main() {
    // [input] 정점의 개수 n, 간선의 개수 m
    cin >> n >> m;
    initDist();
    // [input] m개의 줄에 걸쳐, 각 간선의 양 끝 점과 가중치
    for(int i=0; i<m; ++i){
        int a, b, w;
        cin >> a >> b >> w;
        edges[a].push_back({b, w});
        edges[b].push_back({a, w});
    }

    // [solution]
    fill(visited, visited+(n+1), false);
    prim();

    // [solution] 트리의 지름 구하기
    fill(visited, visited+(n+1), false);
    dfs(1, 0);
    fill(visited, visited+(n+1), false);
    dfs(far_node, 0);

    // [output]
    cout << mst_sum << '\n';
    cout << max_route << '\n';
    return 0;
}