BZOJ-3992

题意

集合 $S\subseteq \{x|x\in[0,m-1]\}$ ,由这些数组成一个长度为 $N$ 的数列,给定整数 $r(r\in[0,m-1])$,求满足数列中所有数的乘积 $\mod m$ 的值等于 $r$ 的不同的数列的有多少个。

题解

将 $S$ 和 $x$ 由 $m$ 的原根来表示,这样就可以变乘为加,将集合 $S$ 表示为 $\large f(x)=a_0x^0+a_1x^1+\cdots+a_{m-2}x^{m-2}$ , $a_i$ 表示 $g^i$ 是否属于 $S$ 将 $f(x)$ 看作一个整体,使用快速幂和 $NTT$ 求解 $f(x)^n$ ,答案就为 $x^r$ 的系数

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
ll inf =0x3f3f3f3f;
ll mod = 1004535809;
const int M = 2.1e5;
const int N = 4.1e4;

ll qPow(ll a, ll b, ll c) { //求(a^b) % c
ll ret = 1;
while (b) {
if (b & 0x1) ret = ret * a % c;
a = a * a % c;
b >>= 1;
}
return ret;
}
int factor[50];
int num_factor = 0;
void decomposition(ll n) {
num_factor = 0;
int m=(int)sqrt(n+0.5);
for(int i=2;i<=m;i++){
if(n%i==0){
factor[num_factor]=i;
while(n%i==0)n/=i;
num_factor++;
}
}
if (n > 1) factor[num_factor++] =n;
}
int root(int m){//求a在模m意义下的阶
int phi=m-1;
decomposition(phi);
int g=2;
while(1){
bool yes=1;
for0(i,num_factor){
if(qPow(g,phi/factor[i],m)==1){
yes=0;
break;
}
}
if(yes)break;
g++;
}
return g;
}
int Map[N];
int n,r[N];
ll g=3,inv_g,inv_n;
void init(int a,int b){
int L=0;n=1;
while(n < a + b) n <<= 1,L++;
inv_g=qPow(3,mod-2,mod);
inv_n = qPow(n,mod-2,mod);
for(int i = 0; i < n; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}
inline void NTT(int *A, int type) {//type=1 系数->点值 ; type=-1 点值->系数
for(int i = 0; i < n; i++) if(i < r[i]) swap(A[i], A[r[i]]);
for(int mid = 1; mid < n; mid <<= 1) {
ll Wn = qPow( type == 1 ? g : inv_g , (mod - 1) / (mid << 1),mod);
for(int j = 0; j < n; j += (mid << 1)) {
ll w = 1;
for(int k = 0; k < mid; k++, w = (w * Wn) % mod) {
int x = A[j + k], y = w * A[j + k + mid] % mod;
A[j + k] = (x + y) % mod,
A[j + k + mid] = (x - y + mod) % mod;
}
}
}
if(type==-1) for0(i,n)A[i]=A[i]*inv_n%mod;
}
int ans[N],a[N],b[N];
int m;
void mult(int a[],int b[]){
NTT(a,1);
NTT(b,1);
for(int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mod;
NTT(a, -1);
for(int i=m-1;i<2*m-1;i++){
a[i%(m-1)]=(a[i%(m-1)]+a[i])%mod;
a[i]=0;
}
}
int main() {
int N,x,s,y;
in(N,m,x,s);
int rt=root(m),tm=1;
for(int i=0;i<m-1;i++){
Map[tm]=i;
tm=tm*rt%m;
}
Map[0]=m-1;
x=Map[x];
bool have0=0;
for0(i,s){
in(y);
a[Map[y]]=1;
if(y==0)have0=1;
}
if(x==m-1){
if(have0){
out((qPow(s,N,mod)-qPow(s-1,N,mod)+mod)%mod,1);
}else out(0,1);
return 0;
}
ans[0]=1;a[m-1]=0;
init(m-1,m-1);
while(N){
if(N%2){
memcpy(b,a,sizeof(int)*(2*m));
memcp(b,a);
mult(ans,b);
}
memcpy(b,a,sizeof(int)*(2*m));
mult(a,b);
N/=2;
}
out(ans[x],1);
return 0;
}