blob: 9d464777891a3149ac9219870489e20ba84b090a [file] [log] [blame]
/* mpn_perfect_power_p -- mpn perfect power detection.

   Contributed to the GNU project by Martin Boij.

Copyright 2009, 2010, 2012, 2014 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */
32
33#include "gmp-impl.h"
34#include "longlong.h"
35
/* Operand-size thresholds, in limbs, used by mpn_perfect_power_p to pick
   how much trial division to do.  Operands of at most SMALL limbs use
   nrtrial[0], operands of at most MEDIUM limbs use nrtrial[1], and larger
   operands use nrtrial[2].  */
enum
{
  SMALL = 20,
  MEDIUM = 100
};
38
39/* Return non-zero if {np,nn} == {xp,xn} ^ k.
40 Algorithm:
41 For s = 1, 2, 4, ..., s_max, compute the s least significant limbs of
42 {xp,xn}^k. Stop if they don't match the s least significant limbs of
43 {np,nn}.
44
45 FIXME: Low xn limbs can be expected to always match, if computed as a mod
46 B^{xn} root. So instead of using mpn_powlo, compute an approximation of the
47 most significant (normalized) limb of {xp,xn} ^ k (and an error bound), and
48 compare to {np, nn}. Or use an even cruder approximation based on fix-point
49 base 2 logarithm. */
50static int
51pow_equals (mp_srcptr np, mp_size_t n,
52 mp_srcptr xp,mp_size_t xn,
53 mp_limb_t k, mp_bitcnt_t f,
54 mp_ptr tp)
55{
56 mp_bitcnt_t y, z;
57 mp_size_t bn;
58 mp_limb_t h, l;
59
60 ASSERT (n > 1 || (n == 1 && np[0] > 1));
61 ASSERT (np[n - 1] > 0);
62 ASSERT (xn > 0);
63
64 if (xn == 1 && xp[0] == 1)
65 return 0;
66
67 z = 1 + (n >> 1);
68 for (bn = 1; bn < z; bn <<= 1)
69 {
70 mpn_powlo (tp, xp, &k, 1, bn, tp + bn);
71 if (mpn_cmp (tp, np, bn) != 0)
72 return 0;
73 }
74
75 /* Final check. Estimate the size of {xp,xn}^k before computing the power
76 with full precision. Optimization: It might pay off to make a more
77 accurate estimation of the logarithm of {xp,xn}, rather than using the
78 index of the MSB. */
79
80 MPN_SIZEINBASE_2EXP(y, xp, xn, 1);
81 y -= 1; /* msb_index (xp, xn) */
82
83 umul_ppmm (h, l, k, y);
84 h -= l == 0; --l; /* two-limb decrement */
85
86 z = f - 1; /* msb_index (np, n) */
87 if (h == 0 && l <= z)
88 {
89 mp_limb_t *tp2;
90 mp_size_t i;
91 int ans;
92 mp_limb_t size;
93 TMP_DECL;
94
95 size = l + k;
96 ASSERT_ALWAYS (size >= k);
97
98 TMP_MARK;
99 y = 2 + size / GMP_LIMB_BITS;
100 tp2 = TMP_ALLOC_LIMBS (y);
101
102 i = mpn_pow_1 (tp, xp, xn, k, tp2);
103 if (i == n && mpn_cmp (tp, np, n) == 0)
104 ans = 1;
105 else
106 ans = 0;
107 TMP_FREE;
108 return ans;
109 }
110
111 return 0;
112}
113
114
115/* Return non-zero if N = {np,n} is a kth power.
116 I = {ip,n} = N^(-1) mod B^n. */
117static int
118is_kth_power (mp_ptr rp, mp_srcptr np,
119 mp_limb_t k, mp_srcptr ip,
120 mp_size_t n, mp_bitcnt_t f,
121 mp_ptr tp)
122{
123 mp_bitcnt_t b;
124 mp_size_t rn, xn;
125
126 ASSERT (n > 0);
127 ASSERT ((k & 1) != 0 || k == 2);
128 ASSERT ((np[0] & 1) != 0);
129
130 if (k == 2)
131 {
132 b = (f + 1) >> 1;
133 rn = 1 + b / GMP_LIMB_BITS;
134 if (mpn_bsqrtinv (rp, ip, b, tp) != 0)
135 {
136 rp[rn - 1] &= (CNST_LIMB(1) << (b % GMP_LIMB_BITS)) - 1;
137 xn = rn;
138 MPN_NORMALIZE (rp, xn);
139 if (pow_equals (np, n, rp, xn, k, f, tp) != 0)
140 return 1;
141
142 /* Check if (2^b - r)^2 == n */
143 mpn_neg (rp, rp, rn);
144 rp[rn - 1] &= (CNST_LIMB(1) << (b % GMP_LIMB_BITS)) - 1;
145 MPN_NORMALIZE (rp, rn);
146 if (pow_equals (np, n, rp, rn, k, f, tp) != 0)
147 return 1;
148 }
149 }
150 else
151 {
152 b = 1 + (f - 1) / k;
153 rn = 1 + (b - 1) / GMP_LIMB_BITS;
154 mpn_brootinv (rp, ip, rn, k, tp);
155 if ((b % GMP_LIMB_BITS) != 0)
156 rp[rn - 1] &= (CNST_LIMB(1) << (b % GMP_LIMB_BITS)) - 1;
157 MPN_NORMALIZE (rp, rn);
158 if (pow_equals (np, n, rp, rn, k, f, tp) != 0)
159 return 1;
160 }
161 MPN_ZERO (rp, rn); /* Untrash rp */
162 return 0;
163}
164
165static int
166perfpow (mp_srcptr np, mp_size_t n,
167 mp_limb_t ub, mp_limb_t g,
168 mp_bitcnt_t f, int neg)
169{
170 mp_ptr ip, tp, rp;
171 mp_limb_t k;
172 int ans;
173 mp_bitcnt_t b;
174 gmp_primesieve_t ps;
175 TMP_DECL;
176
177 ASSERT (n > 0);
178 ASSERT ((np[0] & 1) != 0);
179 ASSERT (ub > 0);
180
181 TMP_MARK;
182 gmp_init_primesieve (&ps);
183 b = (f + 3) >> 1;
184
185 TMP_ALLOC_LIMBS_3 (ip, n, rp, n, tp, 5 * n);
186
187 MPN_ZERO (rp, n);
188
189 /* FIXME: It seems the inverse in ninv is needed only to get non-inverted
190 roots. I.e., is_kth_power computes n^{1/2} as (n^{-1})^{-1/2} and
191 similarly for nth roots. It should be more efficient to compute n^{1/2} as
192 n * n^{-1/2}, with a mullo instead of a binvert. And we can do something
193 similar for kth roots if we switch to an iteration converging to n^{1/k -
194 1}, and we can then eliminate this binvert call. */
195 mpn_binvert (ip, np, 1 + (b - 1) / GMP_LIMB_BITS, tp);
196 if (b % GMP_LIMB_BITS)
197 ip[(b - 1) / GMP_LIMB_BITS] &= (CNST_LIMB(1) << (b % GMP_LIMB_BITS)) - 1;
198
199 if (neg)
200 gmp_nextprime (&ps);
201
202 ans = 0;
203 if (g > 0)
204 {
205 ub = MIN (ub, g + 1);
206 while ((k = gmp_nextprime (&ps)) < ub)
207 {
208 if ((g % k) == 0)
209 {
210 if (is_kth_power (rp, np, k, ip, n, f, tp) != 0)
211 {
212 ans = 1;
213 goto ret;
214 }
215 }
216 }
217 }
218 else
219 {
220 while ((k = gmp_nextprime (&ps)) < ub)
221 {
222 if (is_kth_power (rp, np, k, ip, n, f, tp) != 0)
223 {
224 ans = 1;
225 goto ret;
226 }
227 }
228 }
229 ret:
230 TMP_FREE;
231 return ans;
232}
233
/* Number of primes handed to mpn_trialdiv, indexed by the size class
   computed in mpn_perfect_power_p:
     0: n <= SMALL limbs
     1: SMALL < n <= MEDIUM
     2: n > MEDIUM  */
static const unsigned short nrtrial[] =
{
  100,
  500,
  1000
};

/* logs[i] = log_{p_i} 2 = ln 2 / ln p_i, where p_i is the
   (nrtrial[i] + 1)'th prime number.  mpn_perfect_power_p multiplies this
   by the bit size of the cofactor left after trial division, to bound the
   exponents that are worth testing.  */
static const double logs[] =
{
  0.1099457228193620,   /* p = 547, the 101st prime */
  0.0847016403115322,   /* p = 3581, the 501st prime */
  0.0772048195144415    /* p = 7927, the 1001st prime */
};
240
241int
242mpn_perfect_power_p (mp_srcptr np, mp_size_t n)
243{
244 mp_limb_t *nc, factor, g;
245 mp_limb_t exp, d;
246 mp_bitcnt_t twos, count;
247 int ans, where, neg, trial;
248 TMP_DECL;
249
250 neg = n < 0;
251 if (neg)
252 {
253 n = -n;
254 }
255
256 if (n == 0 || (n == 1 && np[0] == 1)) /* Valgrind doesn't like
257 (n <= (np[0] == 1)) */
258 return 1;
259
260 TMP_MARK;
261
262 count = 0;
263
264 twos = mpn_scan1 (np, 0);
265 if (twos != 0)
266 {
267 mp_size_t s;
268 if (twos == 1)
269 {
270 return 0;
271 }
272 s = twos / GMP_LIMB_BITS;
273 if (s + 1 == n && POW2_P (np[s]))
274 {
275 return ! (neg && POW2_P (twos));
276 }
277 count = twos % GMP_LIMB_BITS;
278 n -= s;
279 np += s;
280 if (count > 0)
281 {
282 nc = TMP_ALLOC_LIMBS (n);
283 mpn_rshift (nc, np, n, count);
284 n -= (nc[n - 1] == 0);
285 np = nc;
286 }
287 }
288 g = twos;
289
290 trial = (n > SMALL) + (n > MEDIUM);
291
292 where = 0;
293 factor = mpn_trialdiv (np, n, nrtrial[trial], &where);
294
295 if (factor != 0)
296 {
297 if (count == 0) /* We did not allocate nc yet. */
298 {
299 nc = TMP_ALLOC_LIMBS (n);
300 }
301
302 /* Remove factors found by trialdiv. Optimization: If remove
303 define _itch, we can allocate its scratch just once */
304
305 do
306 {
307 binvert_limb (d, factor);
308
309 /* After the first round we always have nc == np */
310 exp = mpn_remove (nc, &n, np, n, &d, 1, ~(mp_bitcnt_t)0);
311
312 if (g == 0)
313 g = exp;
314 else
315 g = mpn_gcd_1 (&g, 1, exp);
316
317 if (g == 1)
318 {
319 ans = 0;
320 goto ret;
321 }
322
323 if ((n == 1) & (nc[0] == 1))
324 {
325 ans = ! (neg && POW2_P (g));
326 goto ret;
327 }
328
329 np = nc;
330 factor = mpn_trialdiv (np, n, nrtrial[trial], &where);
331 }
332 while (factor != 0);
333 }
334
335 MPN_SIZEINBASE_2EXP(count, np, n, 1); /* log (np) + 1 */
336 d = (mp_limb_t) (count * logs[trial] + 1e-9) + 1;
337 ans = perfpow (np, n, d, g, count, neg);
338
339 ret:
340 TMP_FREE;
341 return ans;
342}