G - Takahashi And Pass-The-Ball Game Editorial by evima

Another Solution

You can also consider the directed graph (functional graph) with an edge from vertex \(i\) to vertex \(A_i\) for each \(i\). For each vertex \(i\), add \(B_i/K\) to each of the \(K\) vertices reached by following \(1, 2, \dots, K\) edges from \(i\). A vertex reached several times receives the amount that many times, and the answer for each vertex is the total it accumulates. A functional graph consists of several cycles, together with trees rooted at the vertices on those cycles, and following edges from any vertex eventually leads onto a cycle. The \(K\) vertices reached from \(i\) therefore break down simply: first the tree vertices on the path from \(i\) toward its cycle, then consecutive cycle vertices in order, possibly going around the cycle many times when \(K\) is large. Using prefix-sum techniques, the additions from vertex \(i\) onto the cycle vertices and onto the other vertices can each be recorded with a constant number of updates and resolved with a single accumulation pass at the end (see the sample implementation), so the main part of the problem can be solved in linear time.
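The cycle part of this bookkeeping can be illustrated on its own. The sketch below is only an illustration: the function names `add_on_cycle` and `add_on_cycle_naive` are hypothetical, and it assumes each starting vertex is given as a cycle position `p` with weight `b`, to be added to positions \((p+1) \bmod m, \dots, (p+K) \bmod m\). The \(\lfloor K/m \rfloor\) full laps become one constant shared by every position, and the remaining \(K \bmod m\) steps become a range update on a difference array, split in two when it wraps around. This plays the same role as the array `a` in the sample implementation.

```python
def add_on_cycle(m, starts, K):
    """Each (p, b) in `starts` adds b to cycle positions (p + t) % m for
    t = 1, ..., K. Uses O(1) work per start plus one O(m) prefix-sum pass."""
    loops, rem = divmod(K, m)
    base = 0              # amount every position receives from full laps
    diff = [0] * (m + 1)  # difference array for the leftover partial lap
    for p, b in starts:
        base += loops * b
        if rem == 0:
            continue
        lo = (p + 1) % m  # first position of the partial lap
        hi = lo + rem     # one past its last position, before wrapping
        diff[lo] += b
        if hi <= m:
            diff[hi] -= b
        else:             # the partial lap wraps past position m - 1
            diff[m] -= b
            diff[0] += b
            diff[hi - m] -= b
    totals, run = [], 0
    for i in range(m):
        run += diff[i]
        totals.append(base + run)
    return totals


def add_on_cycle_naive(m, starts, K):
    """Direct simulation of the same additions, for checking."""
    totals = [0] * m
    for p, b in starts:
        for t in range(1, K + 1):
            totals[(p + t) % m] += b
    return totals


print(add_on_cycle(3, [(0, 5), (2, 1)], 4))        # [7, 11, 6]
print(add_on_cycle_naive(3, [(0, 5), (2, 1)], 4))  # [7, 11, 6]
```

Keeping the full laps in a single scalar is what keeps the cost independent of \(K\), which matters when \(K\) is much larger than the cycle length.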

Sample Implementation (Python)

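import sys

N, K = map(int, input().split())
sys.setrecursionlimit(2 * N + 99)  # dfs below can recurse about N levels deep
MOD = 998244353
A = list(map(lambda x: int(x) - 1, input().split()))  # 0-indexed; edge i -> A[i]
B = list(map(int, input().split()))
rA = [[] for _ in range(N)]  # reverse edges: rA[j] = all i with A[i] == j
for i in range(N):
    rA[A[i]].append(i)
visited = [False for _ in range(N)]
ans = [0 for _ in range(N)]  # totals before the final division by K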


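def solve(r):
    # r: one cycle, listed so that A[r[i]] == r[(i + 1) % m]
    m = len(r)
    s = set(r)
    a = [0 for _ in range(m + 1)]  # difference array over cycle positions; a[m] is a sink
    l = []  # DFS path from the current cycle vertex down to the parent of v
    loopsK, remK = K // m, K % m  # K steps = loopsK full laps + remK extra steps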

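    def dfs(v, pos):
        visited[v] = True
        # v is len(l) steps away from cycle vertex r[pos]. Treat v as if it sat at
        # cycle position p, so that step t from v maps to position (p + t) % m.
        p = (pos - len(l)) % m
        # Add B[v] to positions p+1, ..., p+K: each full lap covers every position,
        a[0] += loopsK * B[v] % MOD
        toK = p + remK
        # then the remaining remK steps cover p+1..toK, wrapping past m-1 if needed.
        if toK < m:
            a[p + 1] += B[v]
            a[toK + 1] -= B[v]
        else:
            a[0] += B[v]
            a[toK - m + 1] -= B[v]
            a[p + 1] += B[v]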
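        # Steps 1..len(l)-1 (at most K of them) actually land on tree vertices,
        # not on the cycle, so subtract them again in the same way.
        if len(l) > 1:
            loopsL, remL = min(K, len(l) - 1) // m, min(K, len(l) - 1) % m
            a[0] += loopsL * -B[v] % MOD
            toL = p + remL
            if toL < m:
                a[p + 1] += -B[v]
                a[toL + 1] -= -B[v]
            else:
                a[0] += -B[v]
                a[toL - m + 1] -= -B[v]
                a[p + 1] += -B[v]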
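        # Tree vertices 1..K steps above v (excluding the cycle vertex l[0]) receive
        # B[v]: mark +B[v] at the parent and -B[v] at the (K+1)-th ancestor; summing
        # the marks over subtrees below turns them into the correct totals.
        if len(l) > 1:
            ans[l[-1]] += B[v]
        if len(l) > K + 1:
            ans[l[-K - 1]] -= B[v]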
        l.append(v)
        for w in rA[v]:
            if w not in s:  # stay inside the tree: skip cycle vertices
                dfs(w, pos)
                ans[v] += ans[w]  # subtree sum of the marks
        l.pop()

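    # One tree DFS per cycle vertex, rooted at r[i].
    for i in range(m):
        dfs(r[i], i)

    # Prefix sums of the difference array give the totals on the cycle vertices
    # (ans[r[i]] was only scratch space during the tree pass, so overwrite it).
    for i in range(m):
        ans[r[i]] = a[i]
        a[i + 1] += a[i]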


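# Walk edges from each unvisited vertex until an already-visited vertex appears.
# solve() marks whole components as visited, so that vertex lies on the current
# walk, and the part of the walk from it onward is a new cycle.
for i in range(N):
    if not visited[i]:
        cur = i
        l = []
        while not visited[cur]:
            visited[cur] = True
            l.append(cur)
            cur = A[cur]
        solve(l[l.index(cur):])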
d = pow(K, MOD - 2, MOD)  # modular inverse of K (Fermat; MOD is prime)
print(' '.join(map(lambda x: str(x % MOD * d % MOD), ans)))
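For testing on small inputs, a direct \(O(NK)\) simulation of the same quantity may be handy. The sketch below is hypothetical (the function name `brute_force` and its 0-indexed `A` are assumptions, not part of the sample). It follows the \(1, 2, \dots, K\) edges from every vertex explicitly and applies the same final multiplication by the inverse of \(K\) modulo 998244353, so its output can be compared with the sample implementation's on small random cases.

```python
MOD = 998244353


def brute_force(N, K, A, B):
    """O(N * K) reference. A is 0-indexed (edge i -> A[i]).
    Returns, for each vertex, (sum of B over all K steps) / K modulo MOD."""
    total = [0] * N
    for i in range(N):
        v = i
        for _ in range(K):    # follow 1, 2, ..., K edges from i
            v = A[v]
            total[v] += B[i]  # B[i] / K lands here; divide by K at the end
    inv_k = pow(K, MOD - 2, MOD)  # modular inverse of K (MOD is prime)
    return [x % MOD * inv_k % MOD for x in total]


print(brute_force(2, 2, [1, 0], [3, 5]))  # [4, 4]
```

Since every vertex adds \(B_i/K\) exactly \(K\) times, the answers always sum to \(\sum_i B_i\) modulo 998244353, which is a cheap extra sanity check.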
