지우너

[코드트리] 커지는 간선의 값 C++ 본문

Problem Solving

[코드트리] 커지는 간선의 값 C++

지옹 2024. 11. 4. 13:21

문제

https://www.codetree.ai/missions/9/problems/growing-edge-value?&utm_source=clipboard&utm_medium=text

 

코드트리 | 코딩테스트 준비를 위한 알고리즘 정석

국가대표가 만든 코딩 공부의 가이드북 코딩 왕초보부터 꿈의 직장 코테 합격까지, 국가대표가 엄선한 커리큘럼으로 준비해보세요.

www.codetree.ai

 

코드

#include <iostream>
#include <vector>
#include <queue>

using namespace std;

const int MAX_N = 100'000;

int n, m, k;
vector<pair<int,int> > edges[MAX_N]; // edges[i]: i번 노드에 연결된 {노드, 가중치}
priority_queue<pair<int, int>, vector<pair<int,int> >, greater<>> pq; // {weight, vertex}
int dist[MAX_N+1]; // dist[x]: 현재까지 만들어진 MST와 노드 x를 연결하기 위해 필요한 최소 비용
bool visited[MAX_N+1]={false, };

void InitDist(){
    for(int i=1; i<=n; ++i){
        dist[i]=1e9;
    }
    // 시작점(아무 정점이나 선택 가능) 0으로 초기화
    dist[1]=0;
}

int main() {
    cin >> n >> m >> k;
    for(int i=0; i<m; ++i){
        int a, b, w;
        cin >> a >> b >> w;
        edges[a].push_back({b, w});
        edges[b].push_back({a, w});
    }

    // dist 배열을 초기화(INF), 출발지(1)의 값 0
    InitDist();
    
    // 거리 dist 내의 값들 중 최솟값 선택(우선순위 큐를 사용)
        // 다익스트라와 마찬가지로 프림 알고리즘에서도
        // 최솟값을 골라주는 과정을 여러 번 반복하기 때문
    long long answer=0;
    int cnt=0;
    // 시작점을 queue에 넣어줌 {weight, vertex}
    pq.push({0, 1});
    while(!pq.empty()){
        // 가장 거리가 가까운 정점의 정보
        int min_dist=pq.top().first, min_vertex=pq.top().second;
        pq.pop();
        
        // 이미 방문한 노드라면 무시
        if (visited[min_vertex]) continue;

        // 최소값 방문 표시
        visited[min_vertex]=true;
        answer += min_dist+(k*cnt++);
        // graph[min_vertex]에 연결된 정점들 거리 갱신
        for(auto e: edges[min_vertex]){
            int vertex = e.first, weight = e.second;

            if(!visited[vertex] && dist[vertex]>weight){
                dist[vertex] = weight;
                pq.push({weight, vertex});
            }
        }
    }
    answer-=k*(cnt-1);
    cout << answer << '\n';
    return 0;
}

 

위와 같이 풀긴 했는데, 아래의 방법이 더 나은 방법으로 보임

minweight1+(k*0)+minweight2+(k*1)+minweight3+(k*2)+...+minweightN+(k*(N-1))

=minweight1+minweight2+minweight3+...+minvertexN

+(k*0)+(k*1)+(k*2)+...+(k*(N-1))

 

위 식에서 k로 묶어주면 아래와 같아진다.

=minweight1+minweight2+minweight3+...+minweightN

+k(0+1+2+...+(N-1))

0부터 n까지의 합은 n(n+1)/2이다. 예를 들어 0부터 10까지의 합은 10*11/2=55

 

우리는 n개의 노드, n-1개의 간선을 고를 것이다.

간선을 n-1개 고르면 가중치는 n-2번 오른다.

따라서 n(n+1)/2에 n-2를 대입하면 0부터 n-2까지의 합은, (n-2)*(n-1)/2가 된다.

 

결과적으로 minweight1+minweight2+minweight3+...+minweight(n-1)  // n-1개의 간선 선택

+k*(n-2)*(n-1)/2를 하면 된다.

 

최소 간선을 선택하는 과정은 위 코드와 같이 answer+=min_dist로 해주고, 마지막에 answer+=(long long)k*(n-2)*(n-1)/2를 해주면 된다.