[算法] 幂塔
前置知识
- 欧拉函数
- 扩展欧拉定理
- 树状数组 (动态维护)
- 快速幂
相关题目
题目形式
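给定长度为 $n$ 的序列 $a_1, a_2, \dots, a_n$ 和模数 $p$, 求 $a_l^{a_{l+1}^{\cdot^{\cdot^{\cdot^{a_r}}}}} \bmod p$
模数 $p$ 可能是给定的, 也可能是询问中动态变化的
序列 $a$ 可能是常量数组, 也可能是没有规律的数组, 甚至可以是询问中动态变化的
区间 $[l, r]$ 可以是整个序列, 也可以由询问指定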
解决方法
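根据扩展欧拉定理:
$$a^b \equiv \begin{cases} a^{b}, & b < \varphi(p) \\ a^{(b \bmod \varphi(p)) + \varphi(p)}, & b \ge \varphi(p) \end{cases} \pmod p$$
因此一个幂运算取模时, 可以先把指数对模数的欧拉函数取模 (指数不小于 $\varphi(p)$ 时再加上 $\varphi(p)$), 再用快速幂计算
由于欧拉函数每两次运算至少会让模数 $p$ 减少一半, 因此最多只需要计算 $O(\log p)$ 次欧拉函数, 之后 $p$ 就会变成 1, 不用继续计算
因此对于一个区间询问, 暴力递归地进行计算实际上只用计算幂塔的前 $O(\log p)$ 层, 因此复杂度是没有问题的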
适用于扩展欧拉定理的快速幂
让返回结果自动判断是否大于等于模数 $p$: 若真实结果 $a^b \ge p$, 则返回 $(a^b \bmod p) + p$, 即根据扩展欧拉定理自动为指数 (返回值) 加上 $p$; 否则直接返回 $a^b$
// Modular power for power towers (extended Euler theorem).
// For a >= 0: returns a^b mod p when the exact value a^b is smaller than p,
// and (a^b mod p) + p otherwise, so the result can be used directly as an
// exponent one layer further up the tower.
i64 fpow(i64 a, i64 b, i64 p) {
    // a^0 is exactly 1, whatever a is.
    if (b == 0) return 1;
    bool overflowed = a >= p;  // has any intermediate value reached p?
    if (overflowed) a %= p;
    i64 acc = 1;
    while (b != 0) {
        if (b & 1) {
            acc *= a;
            if (acc >= p) {
                acc %= p;
                overflowed = true;
            }
        }
        // Square only while a higher bit of b remains, so the squared base
        // is one that will actually be multiplied into the result.
        if (b > 1) {
            a *= a;
            if (a >= p) {
                a %= p;
                overflowed = true;
            }
        }
        b >>= 1;
    }
    return overflowed ? acc + p : acc;
}
// 某些题目因为幂塔上大部分的数固定且连续, 而且询问次数比较多, 可能需要用到分块预处理的 O(1) 快速幂 (光速幂) 来计算
欧拉函数的计算
一般有两种
- 线性筛法, 适用于模数会随询问变化, 且模数不超过筛法上界 (如下面代码中的 $N = 2 \times 10^7$) 的情况
// Upper bound of the linear sieve and the tables it fills.
constexpr int N{20'000'000};
std::vector<int> pri{};          // primes found by the sieve, in increasing order
std::bitset<N + 5> not_prime{};  // marked true for composites by the sieve
std::array<int, N + 5> phi{};    // phi[x]: Euler's totient of x, filled by the sieve
int main() {
// ...
phi[1] = 1;
for (int i = 2; i <= N; ++i) {
if (not not_prime[i]) {
pri.push_back(i);
phi[i] = i - 1;
}
for (auto p : pri) {
if (1ll * i * p > N) break;
not_prime[i * p] = true;
if (i % p == 0) {
phi[i * p] = phi[i] * p;
break;
}
phi[i * p] = phi[i] * phi[p];
}
}
// ...
}- 打表欧拉函数法, 预处理出需要用的欧拉函数值, 适用于模数不变且较大的情况
// Euler's totient phi(n) by trial division, O(sqrt(n)).
// Every distinct prime factor d of n contributes a factor (1 - 1/d).
i64 get_phi(i64 n) {
    i64 result = n;
    for (i64 d = 2; d * d <= n; ++d) {
        if (n % d != 0) continue;
        // d is prime here: all smaller prime factors were already divided out.
        // d still divides result, so this equals result / d * (d - 1).
        result -= result / d;
        do {
            n /= d;
        } while (n % d == 0);
    }
    // A remainder above 1 is one prime factor larger than sqrt of the input.
    if (n > 1) result -= result / n;
    return result;
}
void solve() {
// ...
std::vector<i64> phi; [&phi, m]() mutable {
while (m > 1) {
phi.push_back(m);
m = get_phi(m);
}
phi.push_back(1ll);
}();
// ...
}幂塔的递归计算
一般是这样
auto calc = [&](this auto&& calc, int l, int r, int mod_i) -> i64 {
if (phi[mod_i] == 1 || l > r) return 1ll;
return fpow(a[l], calc(l + 1, r, mod_i + 1), phi[mod_i]);
};根据题目不同可能递归的终止条件不同, 但是 一般是必要的一个条件
例题代码展示
#include <iostream>
#include <vector>
#include <bitset>
#include <array>
using i64 = long long;
// Upper bound of the linear sieve; also the largest modulus a query may use.
constexpr int N{20'000'000};
std::vector<int> pri{};          // primes found by the sieve, in increasing order
std::bitset<N + 5> not_prime{};  // marked true for composites by the sieve
std::array<int, N + 5> phi{};    // phi[x]: Euler's totient of x, filled by the sieve
// Fenwick tree built in place on a caller-owned, 1-indexed vector (index 0
// unused). Supports range add and prefix-sum point query.
struct BIT {
    std::vector<i64>& a;  // borrowed storage; the caller's vector is rewritten
    int n;                // usable positions are a[1..n]

    // O(n) in-place build: push each node's partial sum into its parent node.
    BIT(std::vector<i64>& src) : a(src), n(src.size() - 1) {
        for (int node = 1; node <= n; ++node) {
            const int parent = node + (node & -node);
            if (parent <= n) a[parent] += a[node];
        }
    }

    // Adds x to positions [l, r] via two point updates on the stored differences.
    void radd(int l, int r, i64 x) {
        const auto bump = [this](int pos, i64 delta) {
            while (pos <= n) {
                a[pos] += delta;
                pos += pos & -pos;
            }
        };
        bump(l, x);
        bump(r + 1, -x);
    }

    // Prefix sum of the stored values over 1..i. When the tree was built over
    // a difference array (as solve() does), this is the current value at i.
    i64 que(int i) {
        i64 total = 0;
        while (i != 0) {
            total += a[i];
            i -= i & -i;
        }
        return total;
    }
};
// Modular power for power towers (extended Euler theorem).
// For a >= 0, b >= 0, p >= 1: returns a^b mod p when the exact value a^b is
// smaller than p, and (a^b mod p) + p otherwise. The "+ p" tells the caller
// that the exact value was >= p, so the result can be used directly as an
// exponent modulo some M with phi(M) == p.
i64 fpow(i64 a, i64 b, i64 p) {
    // a^0 is exactly 1. Without this early return, a >= p would set the flag
    // below and report 1 as 1 + p (e.g. 7^0 mod 5 -> 6). That is a wrong
    // exponent whenever a deeper element of the tower is 0.
    if (b == 0) return 1;
    i64 res = 1;
    bool flag = false;  // true once some intermediate value reached p
    if (a >= p) {
        a %= p;
        flag = true;
    }
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) {
                res %= p;
                flag = true;
            }
        }
        // Square only while a higher bit of b remains, so the flag is set
        // only by powers that actually contribute to a^b.
        if (b > 1) {
            a *= a;
            if (a >= p) {
                a %= p;
                flag = true;
            }
        }
    }
    return flag ? res + p : res;
}
// Reads n values and m operations:
//   "1 l r v": add v to a[l..r];
//   "2 l r v": print a[l]^(a[l+1]^(...^a[r])) mod v.
void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<i64> a(n + 1, 0);
    for (int i = 1; i <= n; ++i) std::cin >> a[i];
    // Turn the values into a difference array in place (a[0] stays 0), so the
    // BIT can do range add + point query.
    for (int i = n; i > 0; --i) a[i] -= a[i - 1];
    BIT t(a);

    // Tower over [l, r] reduced with fpow modulo p; the next layer uses phi[p].
    auto calc = [&](this auto&& self, int l, int r, int p) -> i64 {
        if (l > r || p == 1) return 1ll;
        const i64 exponent = self(l + 1, r, phi[p]);
        return fpow(t.que(l), exponent, p);
    };

    while (m-- > 0) {
        int op, l, r;
        i64 v;
        std::cin >> op >> l >> r >> v;
        if (op == 1) {
            t.radd(l, r, v);
        } else if (op == 2) {
            std::cout << calc(l, r, v) % v << "\n";
        }
    }
}
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // Linear (Euler) sieve: primes up to N in pri, and phi[x] for 1 <= x <= N.
    phi[1] = 1;
    for (int x = 2; x <= N; ++x) {
        if (!not_prime[x]) {
            pri.push_back(x);
            phi[x] = x - 1;  // totient of a prime
        }
        for (const int q : pri) {
            if (1ll * x * q > N) break;
            not_prime[x * q] = true;
            if (x % q == 0) {
                phi[x * q] = phi[x] * q;  // q already divides x
                break;                    // q is x's smallest prime factor
            }
            phi[x * q] = phi[x] * phi[q];  // gcd(x, q) == 1
        }
    }
    int t = 1;
    // std::cin >> t;
    for (; t != 0; --t) solve();
}
#include <bits/stdc++.h>
using i64 = long long;
// Euler's totient phi(n) via trial division in O(sqrt(n)).
i64 get_phi(i64 n) {
    i64 phi_n = n;
    for (i64 f = 2; f * f <= n; ++f) {
        if (n % f) continue;
        phi_n = phi_n / f * (f - 1);  // apply the factor (1 - 1/f) exactly
        while (n % f == 0) n /= f;    // strip every copy of f
    }
    // A leftover n > 1 is the single prime factor above sqrt of the input.
    return n > 1 ? phi_n / n * (n - 1) : phi_n;
}
// Modular power for power towers (extended Euler theorem).
// For a >= 0, b >= 0, p >= 1: returns a^b mod p when the exact value a^b is
// smaller than p, and (a^b mod p) + p otherwise. The "+ p" tells the caller
// that the exact value was >= p, so the result can be used directly as an
// exponent modulo some M with phi(M) == p.
i64 fpow(i64 a, i64 b, i64 p) {
    // a^0 is exactly 1. Without this early return, a >= p would set the flag
    // below and report 1 as 1 + p (e.g. 7^0 mod 5 -> 6). That is a wrong
    // exponent whenever a deeper element of the tower is 0.
    if (b == 0) return 1;
    i64 res = 1;
    bool flag = false;  // true once some intermediate value reached p
    if (a >= p) {
        a %= p;
        flag = true;
    }
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) {
                res %= p;
                flag = true;
            }
        }
        // Square only while a higher bit of b remains, so the flag is set
        // only by powers that actually contribute to a^b.
        if (b > 1) {
            a *= a;
            if (a >= p) {
                a %= p;
                flag = true;
            }
        }
    }
    return flag ? res + p : res;
}
void solve() {
int n;
std::cin >> n;
i64 m;
std::cin >> m;
std::vector<i64> phi; [&phi, m]() mutable {
while (m > 1) {
phi.push_back(m);
m = get_phi(m);
}
phi.push_back(1ll);
}();
std::vector<i64> a(n);
for (auto& v : a) std::cin >> v;
auto calc = [&](this auto&& calc, int l, int r, int mod_i) -> i64 {
if (phi[mod_i] == 1 || l > r) return 1ll;
return fpow(a[l], calc(l + 1, r, mod_i + 1), phi[mod_i]);
};
int q;
std::cin >> q;
for (int _ = 0; _ < q; ++_) {
int l, r;
std::cin >> l >> r;
std::cout << calc(l - 1, r - 1, 0) % m << "\n";
}
}
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t = 1;
    // std::cin >> t;
    for (; t != 0; --t) solve();
}
#include <iostream>
#include <vector>
#include <bitset>
#include <array>
using i64 = long long;
// Upper bound of the linear sieve; also the largest modulus a query may use.
constexpr int N{10'000'000};
std::vector<int> pri{};          // primes found by the sieve, in increasing order
std::bitset<N + 5> not_prime{};  // marked true for composites by the sieve
std::array<int, N + 5> phi{};    // phi[x]: Euler's totient of x, filled by the sieve
// Modular power for power towers (extended Euler theorem).
// For a >= 0, b >= 0, p >= 1: returns a^b mod p when the exact value a^b is
// smaller than p, and (a^b mod p) + p otherwise. The "+ p" tells the caller
// that the exact value was >= p, so the result can be used directly as an
// exponent modulo some M with phi(M) == p.
i64 fpow(i64 a, i64 b, i64 p) {
    // a^0 is exactly 1. Without this early return, a >= p would set the flag
    // below and report 1 as 1 + p (e.g. 7^0 mod 5 -> 6).
    if (b == 0) return 1;
    i64 res = 1;
    bool flag = false;  // true once some intermediate value reached p
    if (a >= p) {
        a %= p;
        flag = true;
    }
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) {
                res %= p;
                flag = true;
            }
        }
        // Square only while a higher bit of b remains, so the flag is set
        // only by powers that actually contribute to a^b.
        if (b > 1) {
            a *= a;
            if (a >= p) {
                a %= p;
                flag = true;
            }
        }
    }
    return flag ? res + p : res;
}
// Reads p and prints the infinite tower 2^(2^(2^...)) mod p.
void solve() {
    // Layer with modulus `mod`; the layer above it works modulo phi[mod].
    // Modulus 1 ends the recursion.
    auto tower = [](this auto&& self, int mod) -> i64 {
        return mod == 1 ? 1ll : fpow(2, self(phi[mod]), mod);
    };
    int p;
    std::cin >> p;
    std::cout << tower(p) % p << "\n";
}
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // Linear (Euler) sieve: primes up to N in pri, and phi[x] for 1 <= x <= N.
    phi[1] = 1;
    for (int x = 2; x <= N; ++x) {
        if (!not_prime[x]) {
            pri.push_back(x);
            phi[x] = x - 1;  // totient of a prime
        }
        for (const int q : pri) {
            if (1ll * x * q > N) break;
            not_prime[x * q] = true;
            if (x % q == 0) {
                phi[x * q] = phi[x] * q;  // q already divides x
                break;                    // q is x's smallest prime factor
            }
            phi[x * q] = phi[x] * phi[q];  // gcd(x, q) == 1
        }
    }
    int t = 1;
    std::cin >> t;
    for (; t != 0; --t) solve();
}
#include <iostream>
#include <vector>
#include <bitset>
#include <array>
using i64 = long long;
using u64 = unsigned long long;
// Upper bound of the linear sieve; also the largest modulus the input may use.
constexpr int N{1'000'000};
std::vector<int> pri{};          // primes found by the sieve, in increasing order
std::bitset<N + 5> not_prime{};  // marked true for composites by the sieve
std::array<int, N + 5> phi{};    // phi[x]: Euler's totient of x, filled by the sieve
// Modular power for power towers (extended Euler theorem).
// For a >= 0, b >= 0, p >= 1: returns a^b mod p when the exact value a^b is
// smaller than p, and (a^b mod p) + p otherwise. The "+ p" tells the caller
// that the exact value was >= p, so the result can be used directly as an
// exponent modulo some M with phi(M) == p.
i64 fpow(i64 a, i64 b, i64 p) {
    // a^0 is exactly 1. Without this early return, a >= p would set the flag
    // below and report 1 as 1 + p (e.g. 7^0 mod 5 -> 6).
    if (b == 0) return 1;
    i64 res = 1;
    bool flag = false;  // true once some intermediate value reached p
    if (a >= p) {
        a %= p;
        flag = true;
    }
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) {
                res %= p;
                flag = true;
            }
        }
        // Square only while a higher bit of b remains, so the flag is set
        // only by powers that actually contribute to a^b.
        if (b > 1) {
            a *= a;
            if (a >= p) {
                a %= p;
                flag = true;
            }
        }
    }
    return flag ? res + p : res;
}
// Reads a, b, p and prints the tower a^(a^(...^a)) of height b, modulo p.
void solve() {
    i64 a, p;
    u64 b;
    std::cin >> a >> b >> p;
    // Tower of the given height, reduced with fpow modulo `mod`; the next
    // layer uses phi[mod]. Height 0 or modulus 1 yields 1.
    auto calc = [&](this auto&& self, u64 height, int mod) -> i64 {
        if (mod == 1 || height == 0) return 1ll;
        return fpow(a, self(height - 1, phi[mod]), mod);
    };
    std::cout << calc(b, p) % p << "\n";
}
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // Linear (Euler) sieve: primes up to N in pri, and phi[x] for 1 <= x <= N.
    phi[1] = 1;
    for (int x = 2; x <= N; ++x) {
        if (!not_prime[x]) {
            pri.push_back(x);
            phi[x] = x - 1;  // totient of a prime
        }
        for (const int q : pri) {
            if (1ll * x * q > N) break;
            not_prime[x * q] = true;
            if (x % q == 0) {
                phi[x * q] = phi[x] * q;  // q already divides x
                break;                    // q is x's smallest prime factor
            }
            phi[x * q] = phi[x] * phi[q];  // gcd(x, q) == 1
        }
    }
    int t = 1;
    // std::cin >> t;
    for (; t != 0; --t) solve();
}
- 相逢是问候
使用了分块预处理的 $O(1)$ 快速幂 (光速幂), 以及用并查集跳过欧拉函数链已经归 1 (值不再变化) 的位置, 使得整个序列的每一个元素最多被访问 $O(\log p)$ 次
#include <bits/stdc++.h>
using i64 = long long;
// Block size for the two-level power tables below.
constexpr int SZ{14500};
// Per-modulus power tables of c (row k is filled for the k-th modulus of the
// phi chain in solve()): s1 holds c^j, s2 holds c^(SZ * j).
i64 s1[65][SZ + 1]{}, s2[65][SZ + 1]{};
// Euler's totient phi(n) by trial division, O(sqrt(n)).
i64 get_phi(i64 n) {
    i64 ans = n;
    i64 d = 2;
    while (d * d <= n) {
        if (n % d == 0) {
            ans = ans / d * (d - 1);  // d is a prime factor: multiply by (1 - 1/d)
            do {
                n /= d;
            } while (n % d == 0);
        }
        ++d;
    }
    // Leftover n > 1 is a prime factor larger than sqrt of the input.
    if (n > 1) ans = ans / n * (n - 1);
    return ans;
}
// Disjoint-set union over 0..n+1. merge() always attaches the smaller root
// under the larger one, so get(i) points to a larger index after merge(i, i + 1).
struct DSU {
    std::vector<int> fa;

    // Every element 0..n+1 starts as its own root.
    DSU(int n) : fa(n + 2) {
        std::iota(fa.begin(), fa.end(), 0);
    }

    // Root of x, with path halving.
    int get(int x) {
        for (; fa[x] != x; x = fa[x]) fa[x] = fa[fa[x]];
        return x;
    }

    // Unites the sets of x and y; returns false if they were already united.
    bool merge(int x, int y) {
        const int rx = get(x);
        const int ry = get(y);
        if (rx == ry) return false;
        fa[std::min(rx, ry)] = std::max(rx, ry);
        return true;
    }
};
// Fenwick tree of values modulo p, built in place on a caller-owned,
// 1-indexed vector (index 0 unused). Point add, range-sum query.
struct BIT {
    std::vector<i64>& a;  // borrowed storage; the caller's vector is rewritten
    int n;                // usable positions are a[1..n]
    i64 p;                // modulus for all stored sums

    // O(n) in-place build: fold each node into its parent, modulo p.
    BIT(std::vector<i64>& src, i64 p) : a(src), n(src.size() - 1), p(p) {
        for (int node = 1; node <= n; ++node) {
            const int up = node + (node & -node);
            if (up <= n) a[up] = (a[up] + a[node]) % p;
        }
    }

    // Adds x (normalized into [0, p) first) at position i.
    void add(int i, i64 x) {
        const i64 delta = (x % p + p) % p;
        while (i <= n) {
            a[i] = (a[i] + delta) % p;
            i += i & -i;
        }
    }

    // Sum of positions l..r, modulo p.
    i64 rque(int l, int r) {
        auto prefix = [&](int i) {
            i64 s = 0;
            while (i != 0) {
                s = (s + a[i]) % p;
                i -= i & -i;
            }
            return s;
        };
        return (prefix(r) - prefix(l - 1) + p) % p;
    }
};
void solve() {
int n, m, p, c;
std::cin >> n >> m >> p >> c;
std::vector<i64> phi; [&phi, p]() mutable {
while (p > 1) {
phi.push_back(p);
p = get_phi(p);
}
phi.push_back(1ll);
}();
for (int i = 0; i < phi.size(); ++i) {
i64 mod = phi[i];
s1[i][0] = 1;
for (int j = 1; j <= SZ; ++j) {
s1[i][j] = s1[i][j - 1] * c;
if (s1[i][j] >= mod) s1[i][j] = s1[i][j] % mod + mod;
}
i64 step = s1[i][SZ];
s2[i][0] = 1;
for (int j = 1; j <= SZ; ++j) {
s2[i][j] = s2[i][j - 1] * step;
if (s2[i][j] >= mod) s2[i][j] = s2[i][j] % mod + mod;
}
}
std::vector<i64> a(n + 1, 0);
for (int i = 1; i <= n; ++i) std::cin >> a[i];
auto calc = [&](auto&& calc, int i, int d, int mod_i) -> i64 {
if (mod_i >= phi.size()) return 1ll;
if (d == 0) return a[i] >= phi[mod_i] ? a[i] % phi[mod_i] + phi[mod_i] : a[i];
i64 expo = calc(calc, i, d - 1, mod_i + 1);
i64 val = s1[mod_i][expo % SZ] * s2[mod_i][expo / SZ];
if (val >= phi[mod_i]) val = val % phi[mod_i] + phi[mod_i];
return val;
};
std::vector b = a;
for (auto& x : b) x %= p;
std::vector<int> dp(n + 1, 0);
std::vector<i64> cur(n + 1);
for (int i = 1; i <= n; ++i) cur[i] = a[i] % p;
BIT t(b, p);
DSU dsu(n);
for (int _ = 0; _ < m; ++_) {
int op, l, r;
std::cin >> op >> l >> r;
if (op == 0) {
for (int i = dsu.get(l); i <= r; i = dsu.get(i + 1)) {
t.add(i, -cur[i]);
dp[i]++;
cur[i] = calc(calc, i, dp[i], 0) % p;
t.add(i, cur[i]);
if (dp[i] >= phi.size()) dsu.merge(i, i + 1);
}
}
else if (op == 1) {
std::cout << t.rque(l, r) << "\n";
}
}
}
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t = 1;
    // std::cin >> t;
    for (; t != 0; --t) solve();
}