A natural first idea is to let \(dp_{u,j}\) be the maximum benefit obtainable by choosing \(j\) black nodes in the subtree rooted at \(u\). However, whichever way you try to transfer this state, it turns out to require the sum of depths of the chosen nodes, and recording that sum in the state makes the DP too slow to pass this problem.

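Instead, split the weight of every path into the contributions of its individual edges. For a single edge, it is enough to know how many black nodes lie on one side of it. Let \(T\) be the number of black nodes to choose (called T in the code), and consider an edge \((u, v)\) of weight \(w\), where \(v\) is the child and its subtree has \(s_v\) nodes. If that subtree contains \(k\) black nodes, then \(k(T-k)\) black pairs and \((s_v-k)(n-T-s_v+k)\) white pairs have paths that cross the edge, so the edge contributes \(w\,\big[k(T-k)+(s_v-k)(n-T-s_v+k)\big]\). We can therefore let \(dp_{u,j}\) be the maximum total contribution of all edges in the subtree rooted at \(u\) when that subtree contains exactly \(j\) black nodes, and this state transfers cleanly. One detail concerns the second loop of the tree knapsack: if it includes \(k = 0\), that case needs care. Either treat it separately so that it never feeds an already updated value into later transfers, or enumerate \(k\) in decreasing order so that \(k = 0\) comes last. Otherwise \(dp_{u,j}\) is first updated through \(k = 0\) (from \(dp_{u,j+0}\)) and then reused to compute \(dp_{u,j+k}\), so the same child's contribution is counted twice.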
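
To check the edge-splitting argument, the sketch below (not part of the original solution) computes the same-color distance sum of one fixed coloring twice: directly over all pairs, and as the sum of per-edge contributions \(w\,[k(T-k)+(s_v-k)(n-T-s_v+k)]\). It assumes a hypothetical parent-array representation in which node 0 is the root, `par[v] < v` for every other node, and `w[v]` is the weight of the edge to the parent; `Tree`, `pathWeight`, `bySumOfPaths` and `byEdgeContributions` are illustrative names only.

```cpp
#include <cassert>
#include <vector>

// Hypothetical tree representation: node 0 is the root, par[v] < v for v >= 1,
// and w[v] is the weight of the edge (par[v], v). par[0] and w[0] are unused.
struct Tree {
    std::vector<int> par;
    std::vector<long long> w;
};

// Weight of the path x -> y. Ancestors have smaller indices, so the endpoint
// with the larger index is never the LCA and can safely move up one edge.
long long pathWeight(const Tree& t, int x, int y) {
    long long d = 0;
    while (x != y) {
        if (x > y) { d += t.w[x]; x = t.par[x]; }
        else       { d += t.w[y]; y = t.par[y]; }
    }
    return d;
}

// Direct definition: sum of path weights over all same-colored pairs.
long long bySumOfPaths(const Tree& t, const std::vector<int>& black) {
    int n = (int)t.par.size();
    long long total = 0;
    for (int x = 0; x < n; ++x)
        for (int y = x + 1; y < n; ++y)
            if (black[x] == black[y]) total += pathWeight(t, x, y);
    return total;
}

// Edge decomposition: edge (par[v], v) is crossed by k*(T-k) black pairs and
// (s-k)*((n-T)-(s-k)) white pairs, where s = |subtree(v)| and k = blacks in it.
long long byEdgeContributions(const Tree& t, const std::vector<int>& black) {
    int n = (int)t.par.size();
    long long T = 0;
    for (int c : black) T += c;
    std::vector<long long> sz(n, 1), blk(black.begin(), black.end());
    // Children have larger indices, so a downward sweep finishes each subtree
    // before adding it to its parent.
    for (int v = n - 1; v >= 1; --v) {
        sz[t.par[v]] += sz[v];
        blk[t.par[v]] += blk[v];
    }
    long long total = 0;
    for (int v = 1; v < n; ++v) {
        long long s = sz[v], k = blk[v];
        total += t.w[v] * (k * (T - k) + (s - k) * ((n - T) - (s - k)));
    }
    return total;
}
```

For example, with `Tree t{{-1, 0, 0, 1, 1}, {0, 3, 5, 2, 7}};` the two functions should agree for each of the 32 colorings of the five nodes. Lifting the larger-index endpoint in `pathWeight` only works because of the assumed `par[v] < v` ordering.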

However, with a naive enumeration the method above times out on a chain-shaped tree. The standard way to enumerate a tree knapsack is as follows. When node \(u\) merges its \(i\)-th child \(v_i\), the first loop (over \(j\)) runs only up to the size merged so far, \(1+\sum_{t=1}^{i-1} s_{v_t}\), and the second loop (over \(k\)) only up to \(s_{v_i}\). The total work is then proportional to
\[\sum_{u}\sum_{i} s_{v_i}\Big(1+\sum_{t=1}^{i-1} s_{v_t}\Big) = \binom{n}{2},\]
because each pair of distinct nodes is counted exactly once, at their lowest common ancestor. Hence the complexity is \(O(n^2)\). In this version \(k\) must also be enumerated in decreasing order: if \(k = 0\) were handled first, \(dp_{u,j}\) (that is, \(dp_{u,j+0}\)) would already include the current child when it is later used to compute \(dp_{u,j+k}\), and that child's contribution would be counted twice.
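
The \(O(n^2)\) bound can also be checked empirically. The sketch below (not from the original post) replays only the loop bounds of the merge, with \(j\) up to \(\min(T, s_u)\) and \(k\) up to \(\min(T-j, s_v)\), and counts the iterations. It assumes a hypothetical parent-array tree (`par[v] < v`, root 0) and merges children bottom-up in decreasing index order instead of by DFS; the pair-counting bound does not depend on the order in which children are merged. `countMergeSteps` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Counts the (j, k) iterations performed by the size-bounded tree-knapsack
// merge. Hypothetical tree format: node 0 is the root, par[v] < v for v >= 1.
long long countMergeSteps(const std::vector<int>& par, long long T) {
    int n = (int)par.size();
    std::vector<long long> s(n, 1);  // size merged into each node so far
    long long steps = 0;
    // Children have larger indices, so when v is merged its own size is final.
    for (int v = n - 1; v >= 1; --v) {
        int u = par[v];
        for (long long j = std::min(T, s[u]); j >= 0; --j)
            for (long long k = std::min(T - j, s[v]); k >= 0; --k)
                ++steps;  // one dp[u][j + k] update in the real solution
        s[u] += s[v];
    }
    return steps;
}
```

Each merge performs at most \((s_u+1)(s_v+1) \le 4 s_u s_v\) iterations and the products \(s_u s_v\) sum to \(\binom{n}{2}\), so the count never exceeds \(2n(n-1)\). On a 2000-node chain with \(T = 1000\) this stays under 8 million, whereas loops running to \(T\) at every merge would take about \((n-1)(T+1)^2 \approx 2\times10^9\) iterations.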

```cpp
#include<bits/stdc++.h>
using namespace std;
#define N (2000 + 5)
#define int long long
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define dep(i, l, r) for(int i = r; i >= l; --i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
struct edge{ int v, next, w; }e[N << 1];
int n, u, v, w, T, tot, s[N], h[N], dp[N][N];
int read(){
    char c; int x = 0, f = 1; c = getchar();
    while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar(); }
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}
void add(int u, int v, int w){ // undirected edge, stored in both directions
    e[++tot].v = v, e[tot].w = w, e[tot].next = h[u], h[u] = tot;
    e[++tot].v = u, e[tot].w = w, e[tot].next = h[v], h[v] = tot;
}
// dp[u][j]: max contribution of the edges inside u's subtree when it holds
// j black nodes; -1 marks an unreachable state. s[u]: size merged so far.
void dfs(int u, int fa){
    s[u] = 1, dp[u][0] = dp[u][1] = 0;
    Next(i, u){
        int v = e[i].v; if(v == fa) continue;
        dfs(v, u);
        // j is bounded by the size merged so far and k by s[v]: O(n^2) overall.
        // k runs downward because k = 0 updates dp[u][j] itself, which must
        // happen after every k > 0 that reads dp[u][j].
        dep(j, 0, min(T, s[u])) dep(k, 0, min(T - j, s[v]))
            if(dp[v][k] != -1 && dp[u][j] != -1){
                // black pairs + white pairs whose paths cross edge (u, v)
                int val = (T - k) * k * e[i].w + (s[v] - k) * (n - s[v] - T + k) * e[i].w;
                dp[u][j + k] = max(dp[u][j + k], dp[v][k] + dp[u][j] + val);
            }
        s[u] += s[v];
    }
}
signed main(){
    n = read(), T = read(), memset(dp, -1, sizeof(dp));
    rep(i, 1, n - 1) u = read(), v = read(), w = read(), add(u, v, w);
    dfs(1, 0);
    printf("%lld", dp[1][T]);
    return 0;
}
```
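
To gain some confidence in the program, one can compare the same transition against brute force on small random trees. The sketch below is an assumed re-implementation, not the author's code: it takes a hypothetical parent-array tree (`par[v] < v`, root 0, `w[v]` the nonnegative weight of the edge to the parent) instead of reading input, merges children bottom-up instead of via DFS, and uses the illustrative names `solveDP` and `solveBrute`. The brute force tries every coloring with exactly \(T\) black nodes, so it is only usable for very small \(n\).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical input format: node 0 is the root, par[v] < v for v >= 1,
// and w[v] >= 0 is the weight of the edge (par[v], v).

// Same transition as dfs() in the program above, but children are merged
// bottom-up in decreasing index order instead of by recursion.
long long solveDP(const std::vector<int>& par, const std::vector<long long>& w, int T) {
    int n = (int)par.size();
    // dp[u][j]: best total contribution of edges inside u's subtree with j
    // black nodes in it; -1 marks an unreachable state.
    std::vector<std::vector<long long>> dp(n, std::vector<long long>(T + 1, -1));
    std::vector<long long> s(n, 1);
    for (int u = 0; u < n; ++u) {
        dp[u][0] = 0;
        if (T >= 1) dp[u][1] = 0;
    }
    for (int v = n - 1; v >= 1; --v) {  // all children of v are already merged
        int u = par[v];
        for (long long j = std::min<long long>(T, s[u]); j >= 0; --j)
            for (long long k = std::min<long long>(T - j, s[v]); k >= 0; --k)  // k = 0 last
                if (dp[u][j] != -1 && dp[v][k] != -1) {
                    long long val = w[v] * (k * (T - k) + (s[v] - k) * (n - s[v] - T + k));
                    dp[u][j + k] = std::max(dp[u][j + k], dp[u][j] + dp[v][k] + val);
                }
        s[u] += s[v];
    }
    return dp[0][T];
}

// Path weight x -> y; ancestors have smaller indices, so lift the larger one.
long long pathWeight(const std::vector<int>& par, const std::vector<long long>& w, int x, int y) {
    long long d = 0;
    while (x != y) {
        if (x > y) { d += w[x]; x = par[x]; }
        else       { d += w[y]; y = par[y]; }
    }
    return d;
}

// Tries every coloring with exactly T black nodes (bit i set = node i black).
long long solveBrute(const std::vector<int>& par, const std::vector<long long>& w, int T) {
    int n = (int)par.size();
    long long best = -1;
    for (int mask = 0; mask < (1 << n); ++mask) {
        int cnt = 0;
        for (int i = 0; i < n; ++i) cnt += (mask >> i) & 1;
        if (cnt != T) continue;
        long long total = 0;
        for (int x = 0; x < n; ++x)
            for (int y = x + 1; y < n; ++y)
                if (((mask >> x) & 1) == ((mask >> y) & 1)) total += pathWeight(par, w, x, y);
        best = std::max(best, total);
    }
    return best;
}
```

Calling both functions for every \(T\) from 0 to \(n\) on random trees with up to eight nodes and nonnegative weights is a quick way to catch mistakes in the transition before submitting.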