help
求 ∑n∈[1,N]∑m∈[1,M]∑k∈[0,m) [(nk+x)/m] (N,M<=500000,x为实数且x∈[0,100000])
0.首先令x = [x]显然是不影响结果的
1.先考虑如何化简∑k∈[0,m) [(nk+x)/m]
首先显然有:[(nk+x)/m] = [(nk%m+x)/m]+[nk/m]
[(nk+x)/m] = [(nk%m+x)/m]+(nk-nk%m)/m
则∑k∈[0,m) [(nk+x)/m] = ∑k∈[0,m) [(nk%m+x)/m] + ∑k∈[0,m) nk/m + ∑k∈[0,m) nk%m/m
令d=gcd(n,m),则∑k∈[0,m) [(nk%m+x)/m] = d([x/m]+[(x+d)/m]+...+[(x+m-d)/m])
= d([(x/d)/(m/d)]+[((x/d)+1)/(m/d)]+...+[((x/d)+m/d-1)/(m/d)])
= d([(x/d)(m/d)/(m/d)] (注)
= d[x/d]
∑k∈[0,m) nk%m = d(0+d+...+m-d)
= (m-d)/2
∑k∈[0,m) [(nk+x)/m] = d[x/d]+n(m-1)/2+(d-m)/2
= d[x/d]+(n-1)(m-1)/2+(d-1)/2
由此可见:∑k∈[0,m) [(nk+x)/m] = ∑k∈[0,n) [(mk+x)/n]
注:
对于正整数m和实数x, [mx] = [x] + [x + 1/m] + ... + [x + (m-1)/m]
证明:
x = [x] + {x}
设:k = [m{x}]
则k/m <= m{x} < (k+1)/m
[mx] = [m([x] + {x})]
= m[x] + [m{x}]
= m[x] + k
[x]+[x+1/m]+...+[x+(m-1)/m]
= m[x] + [{x}+1/m] + ... + [{x}+(m-1)/m]
= m[x] + k
[mx] = [x] + [x + 1/m] + ... + [x + (m-1)/m]
2. 再考虑如何计算∑n∈[1,N]∑m∈[1,M] d[x/d]+(n-1)(m-1)/2+(d-1)/2 (d=gcd(n,m))
分为3部分计算:
1) ∑n∈[1,N]∑m∈[1,M] d[x/d]
2) ∑n∈[1,N]∑m∈[1,M] (n-1)(m-1)/2
3) ∑n∈[1,N]∑m∈[1,M] (d-1)/2
1) 和 3) 考虑容斥,枚举d , 再枚举d的倍数,复杂度O(nlogn)
2) 可以直接O(1)计算
3.码码码
#include <vector>
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#define P 998244353
#define r2 499122177
#define MAXN 500005
using namespace std;
#define g() getchar()
template<class Q>inline void Scan(Q&x){
char c; int f=1;
while(c=g(),c<48||c>57)if(c=='-')f=-1;
for(x=0;c>47&&c<58;c=g())x=10*x+c-48;
x*=f;
}
typedef long long ll;
ll n,m,x;
ll p1,p2,p3;
bool ban[MAXN];
vector<int>p;
int mu[MAXN],tot;
inline void get_mu(int n){
mu[1]=1;
for(int i=2;i<=n;++i){
if(!ban[i]){
++tot;
p.push_back(i);
mu[i]=-1;
}
for(int j=0;j<tot;++j){
if((ll)i*p[j]>n)break;
int x=i*p[j];
ban[x]=1;
if(i%p[j]){
mu[x]=P-mu[i];
}
else{
mu[x]=0;
break;
}
}
}
}
inline void set_IO(){
freopen("help.in","r",stdin);
freopen("help.out","w",stdout);
}
int main(){
set_IO();
Scan(n),Scan(m);
if(n>m)swap(n,m);
get_mu(n);
Scan(x);
p3=(P-n*m%P)%P;
for(int d=1;d<=n;++d)
for(int k=1,D;(D=k*d)<=n;++k){
int a=n/D,b=m/D;
p1=(p1+x/d*d*a%P*b%P*mu[k]%P)%P;
p3=(p3+(ll)d*a%P*b%P*mu[k]%P)%P;
}
p2=((n*(n-1)>>1)%P)*((m*(m-1)>>1)%P)%P;
ll ans=(p1+(p2+p3)*r2%P)%P;
cout<<ans<<endl;
return 0;
}