지우너

[코드트리] 트리의 지름3 C++ 본문

Problem Solving

[코드트리] 트리의 지름3 C++

지옹 2024. 10. 1. 13:10

문제

https://www.codetree.ai/missions/9/problems/diameter-of-tree-3?&utm_source=clipboard&utm_medium=text

 

코드트리 | 코딩테스트 준비를 위한 알고리즘 정석

국가대표가 만든 코딩 공부의 가이드북 코딩 왕초보부터 꿈의 직장 코테 합격까지, 국가대표가 엄선한 커리큘럼으로 준비해보세요.

www.codetree.ai

 

코드

가장 먼 노드 a, b가 있다고 하자(트리의 지름에 해당하는 두 노드).

트리의 지름에 해당하는 최대값을 항상 동일하지만, a->b로 향할 때와 b->a로 향할 때 경로의 값은 달랐다.

그래서 가장 먼 노드에서 각 노드까지의 dist를 구해서 해결할 생각으로 코드를 짰다가 예외가 있다는 걸 알았다(토론 탭의 테스트케이스를 참고하여 알게 됨).

 

그래서 가장 먼 노드1에서 다른 노드들까지의 거리를 dist1에 저장하고, 가장 먼 노드2에서 다른 노드들까지의 거리를 dist2에 저장했다.

dist1에서 지름을 제외한 최장거리와 dist2에서 지름을 제외한 최장거리를 비교하여 더 큰 값을 출력하면 정답이 된다.

 

DFS를 3번이나 하게 되기 때문에 비효율적인 코드인 것 같다.

#include <iostream>
#include <vector>
#include <cstring> // memset
#include <algorithm> //sort
using namespace std;

const int MAX_N = 100001;

int n, far_node, max_dist;
vector<pair<int,int> > edges[MAX_N];
bool visited[MAX_N]={false, };


void DFS(int x, int dist, vector<int> &saveDist){
    // 가장 먼 노드 찾기
    if(dist>max_dist){
        max_dist=dist;
        far_node=x;
    }

    // dist를 벡터에 저장
    saveDist.push_back(dist);

    for(int i=0; i<(int)edges[x].size(); ++i){
        int y=edges[x][i].first, d=edges[x][i].second;
        if(visited[y]) continue;
        visited[y]=true;
        DFS(y, dist+d, saveDist);
    }
}

int main() {
    // input
    cin >> n;
    for(int i=1; i<n; ++i){
        int a, b, d;
        cin >> a >> b >> d;

        // 간선 연결
        edges[a].push_back({b, d});
        edges[b].push_back({a, d});
    }

    // 가장 먼 노드 찾기
    vector<int> tmp;
    visited[1]=true;
    DFS(1, 0, tmp);

    // 가장 먼 노드로부터의 거리 계산
    memset(visited, false, sizeof(visited));
    max_dist=0;
    vector<int> dist1;
    visited[far_node]=true;
    DFS(far_node, 0, dist1);
    
    memset(visited, false, sizeof(visited));
    max_dist=0;
    vector<int> dist2;
    visited[far_node]=true;
    DFS(far_node, 0, dist2);
    
    sort(dist1.begin(), dist1.end());
    sort(dist2.begin(), dist2.end());

    // size-1은 제일 먼 노드와의 거리
    int size = (int)dist1.size();
    cout << max(dist1[size-2], dist2[size-2]) <<'\n';
    return 0;
}