Blackops

初心易得,始终难守

0%

HDU 1402 A * B Problem Plus(FFT入门题)

A * B Problem Plus

Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others)Total Submission(s): 25113 Accepted Submission(s): 6448

Problem Description

Calculate A * B.

Input

Each line will contain two integers A and B. Process to end of file.
Note: the length of each integer will not exceed 50000.

Output

For each case, output A * B in one line.

Sample Input

1
2
1000
2

Sample Output

2
2000

题目链接:HDU 1402

题意就不说了,主要是讲下如何使用快速傅立叶变换做这入门题。

快速傅立叶变换(FFT)可以用来快速地求两个多项式的积,就像$(1+x+2x^2)\times (3x+6x^2)=3x+6x^2+3x^2+6x^3+3x^3+6x^4=3x+9x^2+9x^3+6x^4$

输入是两个多项式的系数,输出是这两个多项式之积的系数(FFT我一开始都不知道它是怎么用的,更不用说去学了)

那么上述的输入就是 $1\;1\;2$ 和 $0\;3\;6\;$,得到的多项式系数是$0\;3\;9\;9\;6\;0\;0\;0$(注意最低系数要补齐,从$x^0$开始,结果的次数界要为$2$的幂次)

各种介绍的博客就不说了,百度搜 FFT学习笔记一大堆。这题如果把输入的数看成十进制下的带权求和多项式,那么就可以以多项式的乘法来得到答案的多项式表示,再把这个多项式用十进制转换成答案即可。

以$1000 \times2$为例,它就是

这两个多项式的乘积,不过不能一开始就把$10$进制这个$10$代入,应该写成

然后把结果的系数用FFT求出来即:

再把$x=10$代入即可。

这里有几个细节问题,如果输入的多项式次数界为$a$和$b$,那么结果的次数界应该为$a+b-1$,那么FFT时所用的$2$的幂次的次数界应该要刚好大于等于它。


又去学了下优化FFT的姿势,把递归版改成了非递归,原理就是直接把一开始的数组按照递归合并时的顺序排序,然后就做一个类似于倍增的合并操作就行了。

学习的时候发现有几个要注意和改进的地方:

  1. 重复利用某一个complex数组的时候,在当前次数界$n$之后的复数要手动置$0$,否则答案越到后面会偏差越大
  2. 函数参数$f$表示正逆变换,做逆变换的时候可以把要返回的数组乘以$1 \over n$,方便一点。
  3. 可以预处理出所有要用的$cos(2\pi/i)$和$sin(2\pi/i)$的值,速度快一点。
  4. FFT函数可以写成返回一个数组的首地址,这样就可以用指针接收这个地址,方便后续进行操作(数据一大估计得delete 一下23333)
  5. ans数组如果多次使用,其末尾也要清零,否则末尾会遗留下上次可能超过末尾的部分数字。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <bits/stdc++.h>
//#pragma GCC optimize(2)
using namespace std;
#define INF 0x3f3f3f3f
#define LC(x) (x<<1)
#define RC(x) ((x<<1)+1)
#define MID(x,y) ((x+y)>>1)
#define all(a) (a).begin(),(a).end()
#define pb(x) push_back(x)
#define CLR(arr,val) memset(arr,val,sizeof(arr))
#define FAST_IO ios::sync_with_stdio(false);cin.tie(0);
#define caseT int _T;scanf("%d",&_T);for (int q=1; q<=_T; ++q)
typedef pair<int, int> pii;
typedef complex<double> cpx;
typedef long long LL;
const double PI = acos(-1.0);
const int N = 50005;
char a[N], b[N];
int R[N << 2], ans[N << 1];
cpx x[N << 2], pwm[20];//pwm是预处理出2*PI/(1<<i)的正弦和余弦值组成的复数

int rpos(int x, int n)//以n位二进制表示的x的反转之后的值
{
int w = 0;
for (int i = 0; (1 << i) < n; ++i)
w = (w << 1) | ((x >> i) & 1);
return w;
}
cpx* FFT(cpx a[], int n, int f)
{
cpx *A = new cpx[n];
for (int i = 0; i < n; ++i)
A[i] = a[R[i]];
for (int i = 1; (1 << i) <= n; ++i)
{
int m = (1 << i);
// cpx wm(cos(2 * PI / m), f * sin(2 * PI / m));
cpx wm = pwm[i];
if (f == -1)
wm.imag(-wm.imag());
for (int k = 0; k < n; k += m)
{
cpx w(1, 0);
for (int j = 0; j < (m >> 1); ++j)
{
cpx t = w * A[k + j + (m >> 1)], u = A[k + j];
A[k + j] = u + t;
A[k + j + (m >> 1)] = u - t;
w *= wm;
}
}
}
if (!~f)
for (int i = 0; i < n; ++i)
A[i].real(A[i].real() / n);
return A;//返回处理好的数组首地址,后面就可以用了
}
int main(void)
{
int i;
for (i = 0; (1 << i) < (N << 2); ++i)
pwm[i] = cpx(cos(2 * PI / (1 << i)), sin(2 * PI / (1 << i)));
while (~scanf("%s%s", a, b))
{
int la = strlen(a), lb = strlen(b), lc = la + lb - 1, n = 1;
while (n < lc)
n <<= 1;
for (i = 0; i < n; ++i)
R[i] = rpos(i, n);
for (i = 0; i < la; ++i)
x[i] = cpx(a[la - 1 - i] - '0', 0);
for (i = la; i < n; ++i)//记得清零
x[i] = cpx(0, 0);
cpx *A = FFT(x, n, 1);
for (i = 0; i < lb; ++i)
x[i] = cpx(b[lb - 1 - i] - '0', 0);
for (i = lb; i < n; ++i)//记得清零
x[i] = cpx(0, 0);
cpx *B = FFT(x, n, 1);
for (i = 0; i < n; ++i)
A[i] *= B[i];
A = FFT(A, n, -1);
ans[lc] = 0;//记得清零
for (i = 0; i < n; ++i)
ans[i] = int(A[i].real() + 0.5);
for (i = 0; i < lc; ++i)
ans[i + 1] += ans[i] / 10, ans[i] %= 10;
while (lc && !ans[lc])
--lc;
for (i = lc; i >= 0; --i)
printf("%d", ans[i]);
puts("");
delete []A;
delete []B;
}
return 0;
}