G - Takahashi And Pass-The-Ball Game Editorial by evima
Another Solution
You can also consider the directed graph (a functional graph) that has an edge from vertex \(i\) to vertex \(A_i\) for each \(i\). For each vertex \(i\), you add \(B_i/K\) to the running total of each of the \(K\) vertices reached by following \(1, 2, \dots, K\) edges from \(i\). A functional graph consists of several cycles, with trees rooted at the vertices on those cycles, and following edges from any vertex eventually leads into a cycle. The \(K\) target vertices for a given \(i\) therefore break down simply: a stretch of tree vertices on the way to the cycle (if \(i\) is not on a cycle itself), followed by a run around the cycle. By using cumulative-sum techniques to apply the additions from vertex \(i\) to the cycle vertices all at once, and likewise to the other vertices (see the sample implementation), the main part of the problem can be solved in linear time.
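As a reference point, the quantity described above can be computed directly in \(O(NK)\) time. The following is only a brute-force sketch for checking small cases, not the linear-time method. The function name `expected_balls_naive` is hypothetical, `A` is assumed to be 0-indexed already, and the result is returned as exact fractions rather than modulo 998244353.

```python
from fractions import Fraction


def expected_balls_naive(A, B, K):
    """Brute force: vertex i points to A[i] (0-indexed) and holds B[i] balls."""
    n = len(A)
    total = [0] * n
    for i in range(n):
        v = i
        # Follow 1, 2, ..., K edges from i and credit B[i] to each vertex reached.
        for _ in range(K):
            v = A[v]
            total[v] += B[i]
    # Each of the K walk lengths is equally likely, hence the division by K.
    return [Fraction(t, K) for t in total]
```

For example, on the 3-cycle `A = [1, 2, 0]` with `B = [1, 2, 3]` and `K = 2`, the totals are 5, 4, 3, so the expected values are 5/2, 2, 3/2.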
Sample Implementation (Python)
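import sys

N, K = map(int, input().split())
sys.setrecursionlimit(2 * N + 99)
MOD = 998244353
A = list(map(lambda x: int(x) - 1, input().split()))  # convert to 0-indexed
B = list(map(int, input().split()))

# rA[v]: vertices that have an edge into v (the reversed graph)
rA = [[] for _ in range(N)]
for i in range(N):
    rA[A[i]].append(i)
visited = [False for _ in range(N)]
ans = [0 for _ in range(N)]


def solve(r):
    # r: the vertices of one cycle, listed in the order the edges follow
    m = len(r)
    s = set(r)
    a = [0 for _ in range(m + 1)]  # difference array over cycle positions
    l = []  # ancestors of the current vertex, starting from the cycle vertex
    loopsK, remK = K // m, K % m

    def dfs(v, pos):
        visited[v] = True
        # Treat v as if it sat len(l) steps before position pos on the cycle
        p = (pos - len(l)) % m
        # Add B[v] to the K positions p + 1, ..., p + K (taken mod m)
        a[0] += loopsK * B[v] % MOD
        toK = p + remK
        if toK < m:
            a[p + 1] += B[v]
            a[toK + 1] -= B[v]
        else:
            a[0] += B[v]
            a[toK - m + 1] -= B[v]
            a[p + 1] += B[v]
        if len(l) > 1:
            # The first min(K, len(l) - 1) steps land on tree vertices,
            # so cancel them from the cycle positions
            loopsL, remL = min(K, len(l) - 1) // m, min(K, len(l) - 1) % m
            a[0] += loopsL * -B[v] % MOD
            toL = p + remL
            if toL < m:
                a[p + 1] += -B[v]
                a[toL + 1] -= -B[v]
            else:
                a[0] += -B[v]
                a[toL - m + 1] -= -B[v]
                a[p + 1] += -B[v]
        if len(l) > 1:
            # Add B[v] to the nearest (up to K) tree ancestors via subtree sums
            ans[l[-1]] += B[v]
            if len(l) > K + 1:
                ans[l[-K - 1]] -= B[v]
        l.append(v)
        for w in rA[v]:
            if w not in s:
                dfs(w, pos)
                ans[v] += ans[w]
        l.pop()

    for i in range(m):
        dfs(r[i], i)
    # Prefix sums of the difference array give the totals on the cycle
    for i in range(m):
        ans[r[i]] = a[i]
        a[i + 1] += a[i]


for i in range(N):
    if not visited[i]:
        # Walk forward until a vertex repeats; the repeated part is the cycle
        cur = i
        l = []
        while not visited[cur]:
            visited[cur] = True
            l.append(cur)
            cur = A[cur]
        solve(l[l.index(cur):])

d = pow(K, MOD - 2, MOD)  # modular inverse of K
print(' '.join(map(lambda x: str(x % MOD * d % MOD), ans)))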
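The wrap-around range additions that `dfs` performs on the array `a` can be looked at in isolation. The sketch below is an illustrative restatement of that idea, not part of the sample implementation. The helper names `add_on_cycle` and `materialize` are hypothetical, and values are kept as plain integers without the modulus.

```python
def add_on_cycle(diff, m, p, k, x):
    """Add x to the k cycle positions p + 1, ..., p + k (indices mod m).

    diff is a difference array of length m + 1. A position is hit
    several times when k >= m.
    """
    loops, rem = divmod(k, m)
    diff[0] += loops * x      # every position receives x once per full loop
    end = p + rem             # last position of the partial loop, not reduced
    if end < m:
        diff[p + 1] += x
        diff[end + 1] -= x
    else:
        # The partial range wraps: p + 1 .. m - 1, then 0 .. end - m
        diff[p + 1] += x
        diff[0] += x
        diff[end - m + 1] -= x


def materialize(diff, m):
    """Turn the difference array into per-position totals by prefix sums."""
    out = []
    run = 0
    for i in range(m):
        run += diff[i]
        out.append(run)
    return out
```

Each call costs \(O(1)\) regardless of \(k\), and a single prefix-sum pass at the end recovers all totals, which is what keeps the cycle part linear overall.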