题目描述
给你一个长为n的序列a
m次查询
每次查询一个区间的所有子区间的gcd的和mod1e9+7的结果
输入描述:
第一行两个数n,m
之后一行n个数表示a
之后m行每行两个数l,r表示查询的区间
输出描述:
对于每个询问,输出一行一个数表示答案
示例1
输入
5 7 30 60 20 20 20 1 1 1 5 2 4 3 4 3 5 2 5 2 3
输出
30 330 160 60 120 240 100
说明
[1,1]的子区间只有[1,1],其gcd为30
[1,5]的子区间有:
[1,1]=30,[1,2]=30,[1,3]=10,[1,4]=10,[1,5]=10
[2,2]=60,[2,3]=20,[2,4]=20,[2,5]=20
[3,3]=20,[3,4]=20,[3,5]=20
[4,4]=20,[4,5]=20
[5,5]=20
总共330
[2,4]的子区间有:
[2,2]=60,[2,3]=20,[2,4]=20
[3,3]=20,[3,4]=20
[4,4]=20
总共160
[3,4]的子区间有:
[3,3]=20,[3,4]=20
[4,4]=20
总共60
[3,5]的子区间有:
[3,3]=20,[3,4]=20,[3,5]=20
[4,4]=20,[4,5]=20
[5,5]=20
总共120
[2,5]的子区间有:
[2,2]=60,[2,3]=20,[2,4]=20,[2,5]=20
[3,3]=20,[3,4]=20,[3,5]=20
[4,4]=20,[4,5]=20
[5,5]=20
总共240
[2,3]的子区间有:
[2,2]=60,[2,3]=20
[3,3]=20
总共100
备注:
对于100%的数据,有1 <= n , m , ai <= 100000
题解
倍增预处理、莫队算法。
类似的题目做过好几个了,有一个比较重要的性质:以$i$为起点的区间,区间$gcd$的值只有$log(n)$种。
这样莫队转移的时候,只要把那$log(n)$种都算一下就$ok$了。
有点卡常,优化了一点才过。
#include <bits/stdc++.h> using namespace std; const int maxn = 1e5 + 10; const long long mod = 1e9 + 7; int a[maxn]; int pos[maxn]; int n, m, L, R; long long Ans; struct X { int l, r, id; }s[maxn]; long long ans[maxn]; struct P { int end1; int end2; long long sum; int GCD; int nx; }p[maxn * 40]; int cnt; int List[maxn][2]; /* st */ int dp[maxn][30]; int gcd(int a, int b) { if(b == 0) return a; return gcd(b, a % b); } void init() { for(int i = 1; i <= n; i ++) { dp[i][0] = a[i]; } for(int i = 1; (1 << i) <= n; i ++) { for(int j = 1; j + (1 << i) - 1 <= n; j ++) { dp[j][i] = gcd(dp[j][i - 1], dp[j + (1 << (i - 1))][i - 1]); } } } int query(int l, int r) { int k = (int)(log(double(r - l + 1)) / log((double)2)); return gcd(dp[l][k], dp[r - (1 << k) + 1][k]); } /* st */ bool cmp(const X& a, const X& b) { if (pos[a.l] != pos[b.l]) return a.l < b.l; if((pos[a.l]) & 1) return a.r > b.r; return a.r < b.r; } void add(int x, int op) { int it; for(it = List[x][op]; it != -1; it = p[it].nx) { int id = it; if(p[id].end2 < L || p[id].end1 > R) continue; if(p[id].end1 >= L && p[id].end2 <= R) { Ans = Ans + p[id].sum; } else { int ll = max(p[id].end1, L); int rr = min(p[id].end2, R); Ans = Ans + 1LL * (rr - ll + 1) * p[id].GCD; } } } void del(int x, int op) { int it; for(it = List[x][op]; it != -1; it = p[it].nx) { int id = it; if(p[id].end2 < L || p[id].end1 > R) continue; if(p[id].end1 >= L && p[id].end2 <= R) { Ans = Ans - p[id].sum; } else { int ll = max(p[id].end1, L); int rr = min(p[id].end2, R); Ans = Ans - 1LL * (rr - ll + 1) * p[id].GCD; } } } int main() { scanf("%d%d", &n, &m); int sz = sqrt(n); for(int i = 1; i <= n; i ++) { scanf("%d", &a[i]); pos[i] = i / sz; List[i][0] = List[i][1] = -1; } init(); for(int i = 1; i <= n; i ++) { int ll = i, rr = i; while(ll <= n) { int left = ll, right = n; int g = query(i, ll); while(left <= right) { int mid = (left + right) / 2; if(g == query(i, mid)) { rr = mid, left = mid + 1; } else { right = mid - 1; } } p[cnt].end1 = ll; p[cnt].end2 = rr; p[cnt].sum = 1LL * g * (rr - ll + 1); p[cnt].GCD = g; p[cnt].nx = List[i][0]; List[i][0] = cnt; ll = rr + 1; cnt ++; } } for(int i = 1; i <= n; i ++) { int ll = i, rr = i; while(rr >= 1) { int left = 1, right = rr; int g = query(rr, i); while(left <= right) { int mid = (left + right) / 2; if(g == query(mid, i)) { ll = mid, right = mid - 1; } else { left = mid + 1; } } p[cnt].end1 = ll; p[cnt].end2 = rr; p[cnt].sum = 1LL * g * (rr - ll + 1); p[cnt].GCD = g; p[cnt].nx = List[i][1]; List[i][1] = cnt; rr = ll - 1; cnt ++; } } for(int i = 1; i <= m; i ++) { scanf("%d%d", &s[i].l, &s[i].r); s[i].id = i; } sort(s + 1, s + m + 1, cmp); L = s[1].l; R = s[1].l - 1; Ans = 0; for(int i = s[1].l; i <= s[1].r; i ++) { R ++; add(i, 1); } ans[s[1].id] = Ans; for(int i = 2; i <= m; i ++) { while (L > s[i].l) { L --, add(L, 0); } while (R < s[i].r) { R ++, add(R, 1); } while (L < s[i].l) { del(L, 0), L ++; } while (R > s[i].r) { del(R, 1), R --; } ans[s[i].id] = Ans; } for(int i = 1; i <= m; i ++) { printf("%lld\n", ans[i] % mod); } return 0; }