[算法] 幂塔

前置知识

  • 欧拉函数
  • 扩展欧拉定理
  • 树状数组 (动态维护)
  • 快速幂

相关题目

题目形式

给定长度为 nn 的序列 aa 和模数 mm, 求

a1a2a3anmodma_1^{a_2^{a_3^{\dots ^{a_n}}}} \mod m

mm 可能是给定的也可能是询问中动态变化的
aa 可能是常量数组也可能是没有规律的数组, 甚至可以是询问中动态变化的
区间 可以是整个序列也可以是询问中指定 [l,r][l, r]

解决方法

根据扩展欧拉定理:

abmodm={abmodϕ(m)if gcd(a,m)=1abif gcd(a,m)1 and b<ϕ(m)abmodϕ(m)+ϕ(m)if gcd(a,m)1 and bϕ(m)a^b \mod m = \begin{cases} a^{b \mod \phi(m)} & \text{if } \gcd(a, m) = 1 \\ a^b & \text{if } \gcd(a, m) \neq 1 \text{ and } b < \phi(m) \\ a^{b \mod \phi(m) + \phi(m)} & \text{if } \gcd(a, m) \neq 1 \text{ and } b \geq \phi(m) \end{cases}

我们知道一个幂运算的模可以用快速幂 + 指数对模的欧拉函数取模来计算
由于欧拉函数每两次运算至少会让模数 m 减少一半, 因此最多只需要计算 O(log2m)O(\log_2 m) 次欧拉函数, 之后 ϕ(m)\phi(m) 就会变成 1, 不用继续计算
因此对于一个区间询问, 暴力递归地进行计算实际上只用计算幂塔的前 O(log2m)O(\log_2 m) 层, 因此复杂度是没有问题的

适用于拓展欧拉定理的快速幂

让返回结果自动判断是否大于等于模数 mm, 并自动根据扩展欧拉定理为指数 (返回值) 加上 mm

i64 fpow(i64 a, i64 b, i64 p) {
    if (b == 0) return 1;
    i64 res = 1;
    bool flag = false;
    if (a >= p) a %= p, flag = true;
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) res %= p, flag = true;
        }
        if (b > 1) {
            a *= a;
            if (a >= p) a %= p, flag = true;
        }
    }
    return res + flag * p;
}

某些题目因为幂塔上大部分的数固定且连续, 而且询问次数比较多, 可能需要用到分块预处理的 O(1)O(1) 快速幂来计算

欧拉函数的计算

一般有两种

  1. 线性筛法, 适用于模数会随询问变化, 且模数小于 10710^7
constexpr int N = 20'000'000;
 
std::vector<int> pri;
std::bitset<N + 5> not_prime;
std::array<int, N + 5> phi;
 
int main() {
    // ...
    phi[1] = 1;
    for (int i = 2; i <= N; ++i) {
        if (not not_prime[i]) {
            pri.push_back(i);
            phi[i] = i - 1;
        }
        for (auto p : pri) {
            if (1ll * i * p > N) break;
            not_prime[i * p] = true;
            if (i % p == 0) {
                phi[i * p] = phi[i] * p;
                break;
            }
            phi[i * p] = phi[i] * phi[p];
        }
    }
    // ...
}
  1. 打表欧拉函数法, 预处理出需要用的欧拉函数值, 适用于模数不变且较大的情况
i64 get_phi(i64 n) {
    i64 res = n;
    for (i64 i = 2; i * i <= n; ++i) {
        if (n % i == 0) {
            res = res / i * (i - 1);
            while (n % i == 0) n /= i;
        }
    }
    if (n > 1) res = res / n * (n - 1);
    return res;
}
 
void solve() {
    // ...
    std::vector<i64> phi; [&phi, m]() mutable {
        while (m > 1) {
            phi.push_back(m);
            m = get_phi(m);
        }
        phi.push_back(1ll);
    }();
    // ...
}

幂塔的递归计算

一般是这样

auto calc = [&](this auto&& calc, int l, int r, int mod_i) -> i64 {
    if (phi[mod_i] == 1 || l > r) return 1ll;
    return fpow(a[l], calc(l + 1, r, mod_i + 1), phi[mod_i]);
};

根据题目不同可能递归的终止条件不同, 但是 ϕ=1\phi = 1 一般是必要的一个条件

例题代码展示

  1. 炸脖龙 I
#include <iostream>
#include <vector>
#include <bitset>
#include <array>
 
using i64 = long long;
constexpr int N = 20'000'000;
 
std::vector<int> pri;
std::bitset<N + 5> not_prime;
std::array<int, N + 5> phi;
 
struct BIT {
    std::vector<i64>& a;
    int n;
    BIT(std::vector<i64>& i) : a(i), n(i.size() - 1) {
        for (int i = 1; i <= n; ++i) {
            int j = i + (i & -i);
            if (j <= n) a[j] += a[i];
        }
    }
 
    void radd(int l, int r, i64 x) {
        auto add = [&](int i, i64 x) {
            for (; i <= n; i += i & -i) a[i] += x;
        };
        add(l, x), add(r + 1, -x);
    }
 
    i64 que(int i) {
        i64 res = 0;
        for (; i; i -= i & -i) res += a[i];
        return res;
    }
};
 
i64 fpow(i64 a, i64 b, i64 p) {
    i64 res = 1;
    bool flag = false;
    if (a >= p) a %= p, flag = true;
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) res %= p, flag = true;
        }
        if (b > 1) {
            a *= a;
            if (a >= p) a %= p, flag = true;
        }
    }
    return res + flag * p;
}
 
void solve() {
    int n, m;
    std::cin >> n >> m;
 
    std::vector<i64> a(n + 1, 0);
    for (int i = 1; i <= n; ++i) std::cin >> a[i];
    for (int i = n; i >= 1; --i) a[i] -= a[i - 1];
 
    BIT t(a);
 
    auto calc = [&](this auto&& calc, int l, int r, int p) -> i64 {
        if (p == 1 || l > r) return 1ll;
        return fpow(t.que(l), calc(l + 1, r, phi[p]), p);
    };
 
    for (int _ = 0; _ < m; ++_) {
        int op, l, r;
        i64 v;
        std::cin >> op >> l >> r >> v;
 
        if (op == 1) t.radd(l, r, v);
        if (op == 2) std::cout << calc(l, r, v) % v << "\n";
    }
}
 
int main() {
    std::cin.tie(nullptr) -> sync_with_stdio(false);
 
    phi[1] = 1;
    for (int i = 2; i <= N; ++i) {
        if (not not_prime[i]) {
            pri.push_back(i);
            phi[i] = i - 1;
        }
        for (auto p : pri) {
            if (1ll * i * p > N) break;
            not_prime[i * p] = true;
            if (i % p == 0) {
                phi[i * p] = phi[i] * p;
                break;
            }
            phi[i * p] = phi[i] * phi[p];
        }
    }
 
    int t = 1;
    // std::cin >> t;
 
    while (t--) solve();
}
  1. Power Tower
#include <bits/stdc++.h>
 
using i64 = long long;
 
i64 get_phi(i64 n) {
    i64 res = n;
    for (i64 i = 2; i * i <= n; ++i) {
        if (n % i == 0) {
            res = res / i * (i - 1);
            while (n % i == 0) n /= i;
        }
    }
    if (n > 1) res = res / n * (n - 1);
    return res;
}
 
i64 fpow(i64 a, i64 b, i64 p) {
    i64 res = 1;
    bool flag = false;
    if (a >= p) a %= p, flag = true;
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) res %= p, flag = true;
        }
        if (b > 1) {
            a *= a;
            if (a >= p && b > 1) a %= p, flag = true;
        }
    }
    return res + flag * p;
}
 
void solve() {
    int n;
    std::cin >> n;
 
    i64 m;
    std::cin >> m;
 
    std::vector<i64> phi; [&phi, m]() mutable {
        while (m > 1) {
            phi.push_back(m);
            m = get_phi(m);
        }
        phi.push_back(1ll);
    }();
 
    std::vector<i64> a(n);
    for (auto& v : a) std::cin >> v;
 
    auto calc = [&](this auto&& calc, int l, int r, int mod_i) -> i64 {
        if (phi[mod_i] == 1 || l > r) return 1ll;
        return fpow(a[l], calc(l + 1, r, mod_i + 1), phi[mod_i]);
    };
 
    int q;
    std::cin >> q;
    
    for (int _ = 0; _ < q; ++_) {
        int l, r;
        std::cin >> l >> r;
 
        std::cout << calc(l - 1, r - 1, 0) % m << "\n";
    }
}
 
int main() {
    std::cin.tie(nullptr) -> sync_with_stdio(false);
 
    int t = 1;
    // std::cin >> t;
 
    while (t--) solve();
}
  1. 上帝与集合的正确用法
#include <iostream>
#include <vector>
#include <bitset>
#include <array>
 
using i64 = long long;
constexpr int N = 10'000'000;
 
std::vector<int> pri;
std::bitset<N + 5> not_prime;
std::array<int, N + 5> phi;
 
i64 fpow(i64 a, i64 b, i64 p) {
    i64 res = 1;
    bool flag = false;
    if (a >= p) a %= p, flag = true;
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) res %= p, flag = true;
        }
        if (b > 1) {
            a *= a;
            if (a >= p) a %= p, flag = true;
        }
    }
    return res + flag * p;
}
 
void solve() {
    auto calc = [&](this auto&& calc, int p) -> i64 {
        if (p == 1) return 1ll;
        return fpow(2, calc(phi[p]), p);
    };
 
    int p;
    std::cin >> p;
 
    std::cout << calc(p) % p << "\n";
}
 
int main() {
    std::cin.tie(nullptr) -> sync_with_stdio(false);
 
    phi[1] = 1;
    for (int i = 2; i <= N; ++i) {
        if (not not_prime[i]) {
            pri.push_back(i);
            phi[i] = i - 1;
        }
        for (auto p : pri) {
            if (1ll * i * p > N) break;
            not_prime[i * p] = true;
            if (i % p == 0) {
                phi[i * p] = phi[i] * p;
                break;
            }
            phi[i * p] = phi[i] * phi[p];
        }
    }
 
    int t = 1;
    std::cin >> t;
 
    while (t--) solve();
}
  1. 乘方塔
#include <iostream>
#include <vector>
#include <bitset>
#include <array>
 
using i64 = long long;
using u64 = unsigned long long;
constexpr int N = 1'000'000;
 
std::vector<int> pri;
std::bitset<N + 5> not_prime;
std::array<int, N + 5> phi;
 
i64 fpow(i64 a, i64 b, i64 p) {
    i64 res = 1;
    bool flag = false;
    if (a >= p) a %= p, flag = true;
    for (; b; b >>= 1) {
        if (b & 1) {
            res *= a;
            if (res >= p) res %= p, flag = true;
        }
        if (b > 1) {
            a *= a;
            if (a >= p) a %= p, flag = true;
        }
    }
    return res + flag * p;
}
 
void solve() {
    i64 a, p;
    u64 b;
    std::cin >> a >> b >> p;
 
    auto calc = [&](this auto&& calc, u64 b, int p) -> i64 {
        if (p == 1 || b == 0) return 1ll;
        return fpow(a, calc(b - 1, phi[p]), p);
    };
 
    std::cout << calc(b, p) % p << "\n";
}
 
int main() {
    std::cin.tie(nullptr) -> sync_with_stdio(false);
 
    phi[1] = 1;
    for (int i = 2; i <= N; ++i) {
        if (not not_prime[i]) {
            pri.push_back(i);
            phi[i] = i - 1;
        }
        for (auto p : pri) {
            if (1ll * i * p > N) break;
            not_prime[i * p] = true;
            if (i % p == 0) {
                phi[i * p] = phi[i] * p;
                break;
            }
            phi[i * p] = phi[i] * phi[p];
        }
    }
 
    int t = 1;
    // std::cin >> t;
 
    while (t--) solve();
}
  1. 相逢是问候
    使用了分块优化快速幂以及并查集维护欧拉函数归1的块, 使得最多整个序列的每一个元素最多被访问 O(logp)O(\log p)
#include <bits/stdc++.h>
 
using i64 = long long;
 
constexpr int SZ = 14500;
i64 s1[65][SZ + 1], s2[65][SZ + 1];
 
i64 get_phi(i64 n) {
    i64 res = n;
    for (i64 i = 2; i * i <= n; ++i) {
        if (n % i == 0) {
            res = res / i * (i - 1);
            while (n % i == 0) n /= i;
        }
    }
    if (n > 1) res = res / n * (n - 1);
    return res;
}
 
struct DSU {
    std::vector<int> fa;
 
    DSU(int n) {
        fa.resize(n + 2);
        std::iota(fa.begin(), fa.end(), 0);
    }
    int get(int x) {
        while (x != fa[x]) {
            x = fa[x] = fa[fa[x]];
        }
        return x;
    }
    bool merge(int x, int y) {
        x = get(x), y = get(y);
        if (x == y) return false;
        if (x < y) std::swap(x, y);
        fa[y] = x;
        return true;
    }
};
 
struct BIT {
    std::vector<i64>& a;
    int n;
    i64 p;
    BIT(std::vector<i64>& i, i64 p) : a(i), n(i.size() - 1), p(p) {
        for (int i = 1; i <= n; ++i) {
            int j = i + (i & -i);
            if (j <= n) a[j] = (a[j] + a[i]) % p;
        }
    }
 
    void add(int i, i64 x) {
        x = (x % p + p) % p;
        for (; i <= n; i += i & -i) a[i] = (a[i] + x) % p;
    }
 
    i64 rque(int l, int r) {
        auto que = [&](int i){
            i64 res = 0;
            for (; i; i -= i & -i) res = (res + a[i]) % p;
            return res;
        };
        return (que(r) - que(l - 1) + p) % p;
    }
};
 
void solve() {
    int n, m, p, c;
    std::cin >> n >> m >> p >> c;
    
    std::vector<i64> phi; [&phi, p]() mutable {
        while (p > 1) {
            phi.push_back(p);
            p = get_phi(p);
        }
        phi.push_back(1ll);
    }();
 
    for (int i = 0; i < phi.size(); ++i) {
        i64 mod = phi[i];
        s1[i][0] = 1;
        for (int j = 1; j <= SZ; ++j) {
            s1[i][j] = s1[i][j - 1] * c;
            if (s1[i][j] >= mod) s1[i][j] = s1[i][j] % mod + mod;
        }
        i64 step = s1[i][SZ];
        s2[i][0] = 1;
        for (int j = 1; j <= SZ; ++j) {
            s2[i][j] = s2[i][j - 1] * step;
            if (s2[i][j] >= mod) s2[i][j] = s2[i][j] % mod + mod;
        }
    }
 
    std::vector<i64> a(n + 1, 0);
    for (int i = 1; i <= n; ++i) std::cin >> a[i];
    
    auto calc = [&](auto&& calc, int i, int d, int mod_i) -> i64 {
        if (mod_i >= phi.size()) return 1ll;
        if (d == 0) return a[i] >= phi[mod_i] ? a[i] % phi[mod_i] + phi[mod_i] : a[i];
        
        i64 expo = calc(calc, i, d - 1, mod_i + 1);
        i64 val = s1[mod_i][expo % SZ] * s2[mod_i][expo / SZ];
        if (val >= phi[mod_i]) val = val % phi[mod_i] + phi[mod_i];
        return val;
    };
 
    std::vector b = a;
    for (auto& x : b) x %= p;
 
    std::vector<int> dp(n + 1, 0);
    std::vector<i64> cur(n + 1);
    for (int i = 1; i <= n; ++i) cur[i] = a[i] % p;
 
    BIT t(b, p);
    DSU dsu(n);
 
    for (int _ = 0; _ < m; ++_) {
        int op, l, r;
        std::cin >> op >> l >> r;
 
        if (op == 0) {
            for (int i = dsu.get(l); i <= r; i = dsu.get(i + 1)) {
                t.add(i, -cur[i]);
                dp[i]++;
                cur[i] = calc(calc, i, dp[i], 0) % p;
                t.add(i, cur[i]);
                if (dp[i] >= phi.size()) dsu.merge(i, i + 1);
            }
        }
        else if (op == 1) {
            std::cout << t.rque(l, r) << "\n";
        }
    }
}
 
int main() {
    std::cin.tie(nullptr) -> sync_with_stdio(false);
 
    int t = 1;
    // std::cin >> t;
 
    while (t--) solve();
}