Problem description.

You are given a tree. If we select 2 distinct nodes uniformly at random, what's the probability that the distance between these 2 nodes is a prime number?


The first line contains a number N: the number of nodes in this tree.
The following N-1 lines each contain a pair a[i] b[i], meaning there is an edge of length 1 between nodes a[i] and b[i].


Output a real number denoting the required probability.
Your answer is accepted if its absolute difference from the reference answer is at most 10^-6.


2 ≤ N ≤ 50,000

The input is guaranteed to be a tree.


Sample input:
5
1 2
2 3
3 4
4 5



There are C(5, 2) = 10 possible pairs, and 5 of them have a prime distance, so the answer is 5/10 = 0.5:

1-3, 2-4, 3-5: 2

1-4, 2-5: 3

Note that 1 is not a prime number.


Source Limit: 50000

#include <algorithm>
#include <cmath>
#include <cstdio>

using namespace std;
typedef long long LL;
const double pi = std::acos(-1.0);
// Lowest set bit of x (x & -x), e.g. low(12) == 4. Used to round up to a power of two.
int low(int x) { return x & -x; }
const int INF = 0x7FFFFFFF;      // INT_MAX; sentinel value for Tree::mx[0]
const int mod = 1000000007;      // NOTE(review): not referenced in this file
const int maxn = 100000 + 10;    // size of the prime sieve and of the tree arrays
// n: node count of the current test; K: unused; x, y: edge endpoints read in main;
// p[0..tot): primes below maxn (filled by init); f[i] != 0 iff i is not prime.
int n, K, x, y, p[maxn], f[maxn], tot;

class FFT
	const static int maxn = 270000;//要注意长度是2^k方
	class Plural
		double x, y;
		Plural(double x = 0.0, double y = 0.0) :x(x), y(y) {}
		Plural operator +(const Plural &a)
			return Plural(x + a.x, y + a.y);
		Plural operator -(const Plural &a)
			return Plural(x - a.x, y - a.y);
		Plural operator *(const Plural &a)
			return Plural(x*a.x - y*a.y, x*a.y + y*a.x);
		Plural operator /(const double &u)
			return Plural(x / u, y / u);
	Plural x[maxn];// x1[maxn], x2[maxn];
	Plural y[maxn];// y1[maxn], y2[maxn];
	int X[maxn];
	int n, len;
	int reverse(int x)
		int ans = 0;
		for (int i = 1, j = n >> 1; j; i <<= 1, j >>= 1) if (x&i) ans |= j;
		return ans;
	Plural w(double x, double y)
		return Plural(cos(2 * pi * x / y), -sin(2 * pi * x / y));
	void setx(int len, int *c)
		this->len = len;
		for (n = len + len + 1; n != low(n); n += low(n));//这里要注意取值
		for (int i = 0; i < n; i++)
			if (i > len) x[i] = Plural(0, 0);
				x[i] = Plural(c[i], 0);
				X[i] = c[i];
	void fft(Plural*x, Plural*y, int flag)
		for (int i = 0; i < n; i++) y[i] = x[reverse(i)];
		for (int i = 1; i < n; i <<= 1)
			Plural uu = w(flag, i + i);
			for (int j = 0; j < n; j += i + i)
				Plural u(1, 0);
				for (int k = j; k < j + i; k++)
					Plural a = y[k];
					//w(flag*(k - j), i + i) 可以去掉u和uu用这个代替,精度高些,代价是耗时多了
					Plural b = u * y[k + i];
					y[k] = a + b;
					y[k + i] = a - b;
					u = u*uu;
		if (flag == -1) for (int i = 0; i < n; i++) y[i] = y[i] / n;
	LL solve()
		fft(x, y, 1);
		for (int i = 0; i < n; i++) y[i] = y[i] * y[i];
		fft(y, x, -1);
		LL res = 0, ans = 0;
		for (int i = 0, j; p[i] < n; i++)
			j = p[i];
			ans = (LL)(x[j].x + 0.5);//调整精度
			if (!(j & 1) && (j >> 1) <= len) ans -= X[j >> 1];
			ans >>= 1;
			res += ans;
		return res;

struct Tree
	int ft[maxn], nt[maxn], u[maxn], sz;
	int vis[maxn], cnt[maxn], mx[maxn], flag, h[maxn], fd[maxn];
	void clear(int n)
		mx[sz = flag = 0] = INF;
		for (int i = 1; i <= n; i++)
			ft[i] = -1;
			vis[i] = 0;
			h[i] = 0;
	void AddEdge(int x, int y)
		u[sz] = y; nt[sz] = ft[x]; ft[x] = sz++;
	int dfs(int x, int fa, int sum)
		int ans = mx[x] = 0;
		cnt[x] = 1;
		for (int i = ft[x]; i != -1; i = nt[i])
			if (vis[u[i]] || u[i] == fa) continue;
			int y = dfs(u[i], x, sum);
			if (mx[y]<mx[ans]) ans = y;
			cnt[x] += cnt[u[i]];
			mx[x] = max(mx[x], cnt[u[i]]);
		mx[x] = max(mx[x], sum - cnt[x]);
		return mx[x] < mx[ans] ? x : ans;
	int get(int x, int fa, int dep)
		int ans = dep;
		if (h[dep] != flag) h[dep] = flag, fd[dep] = 0;
		for (int i = ft[x]; i != -1; i = nt[i])
			if (u[i] == fa || vis[u[i]]) continue;
			ans = max(ans, get(u[i], x, dep + 1));
		return ans;
	LL find(int x, int dep)
		++flag;	fd[0] = 0;
		int len = get(x, -1, dep);
		fft.setx(len, fd);
		return fft.solve();
	LL work(int x, int sum)
		int y = dfs(x, -1, sum);
		LL ans = find(y, 0);  vis[y] = 1;
		for (int i = ft[y]; i != -1; i = nt[i])
			if (vis[u[i]]) continue;
			if (cnt[u[i]] > cnt[y]) cnt[u[i]] = sum - cnt[y];
			ans -= find(u[i], 1);
			ans += work(u[i], cnt[u[i]]);
		return ans;

// Reads the next non-negative decimal integer from stdin into x, skipping any
// non-digit characters before it. The character right after the digits is
// consumed. At end of input x is set to 0 instead of looping forever.
// Minus signs are not handled.
void read(int &x)
{
	int ch = getchar();   // int, not char, so EOF stays distinguishable
	while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
	x = 0;
	for (; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + (ch - '0');
}

void init()
	f[0] = f[1] = 1;
	for (int i = 2; i < maxn; i++)
		if (!f[i]) p[tot++] = i;
		for (int j = 0; j < tot&&i*p[j]<maxn; j++)
			f[i*p[j]] = 1;
			if (i%p[j] == 0) break;

int main()
	while (~scanf("%d", &n))
		for (int i = 1; i < n; i++)
			read(x);	read(y);
			solve.AddEdge(x, y);
			solve.AddEdge(y, x);
		printf("%.6lf\n", solve.work(1, n) * 2.0 / n / (n - 1));
	return 0;

