✔ ModSqrt
(math/mod_sqrt.hpp)

Depends on

- (x^y)%mod (math/mod_pow.hpp)
- util/random_gen.hpp

Verified with

Code

#include<tuple>
#include"mod_pow.hpp"
#include"../util/random_gen.hpp"

/**
 * @brief ModSqrt
 *
 * Finds a square root of a modulo a prime `mod` (Tonelli-Shanks algorithm).
 *
 * @param a   any long long; it is reduced into [0, mod) first, so negative
 *            values and multiples of mod are accepted.
 * @param mod a prime. Intermediate products are computed in long long, so
 *            (mod-1)^2 must fit in long long (roughly mod <= 3.03e9).
 * @return some r in [0, mod) with r*r % mod == a % mod, or -1 if a is a
 *         quadratic non-residue modulo mod. Either of the two roots may be
 *         returned.
 */
long long mod_sqrt(long long a,long long mod){
    // Normalize first: handles negative a and nonzero multiples of mod (root 0).
    a%=mod;
    if(a<0)a+=mod;
    if(mod==2||a==0)return a;
    // Euler's criterion: a is a quadratic residue iff a^((mod-1)/2) == 1.
    if(mod_pow(a,(mod-1)/2,mod)!=1)return -1;
    // mod == 3 (mod 4): a^((mod+1)/4) is a root directly.
    if(mod%4==3)return mod_pow(a,(mod+1)/4,mod);
    // Write mod-1 = q * 2^s with q odd.
    long long q=mod-1,s=0;
    while(q%2==0)q/=2,++s;
    // Smallest quadratic non-residue z. A deterministic scan keeps everything in
    // long long (no narrowing of mod to int) and makes the result reproducible.
    long long z=2;
    while(z<mod&&mod_pow(z,(mod-1)/2,mod)!=mod-1)++z;
    if(z==mod)return -1;  // unreachable for prime mod; guards against misuse
    // Invariants kept by the loop: c^(2^(m-1)) == -1, t^(2^(m-1)) == 1, r^2 == a*t.
    long long c=mod_pow(z,q,mod),t=mod_pow(a,q,mod),r=mod_pow(a,(q+1)/2,mod),m=s;
    while(m>1){
        // t^(2^(m-2)) is +1 or -1; when it is -1, multiplying t by c^2
        // (and r by c) turns it into +1 while preserving r^2 == a*t.
        if(mod_pow(t,1LL<<(m-2),mod)!=1){
            t=t*c%mod*c%mod;
            r=r*c%mod;
        }
        c=c*c%mod;
        --m;
    }
    // With m == 1 the invariant forces t == 1, hence r^2 == a.
    return r;
}
#line 1 "math/mod_sqrt.hpp"
#include<tuple>
#line 1 "math/mod_pow.hpp"
/**
 * @brief (x^y)%mod
 *
 * Binary exponentiation: computes x^y modulo `mod` in O(log y) multiplications.
 *
 * @param x   any long long; it is reduced into [0, mod) first, so negative
 *            bases give the mathematically correct residue.
 * @param y   exponent, expected >= 0 (a negative y is treated like 0).
 * @param mod modulus >= 1. Products of two residues are formed in long long,
 *            so (mod-1)^2 must fit in long long (roughly mod <= 3.03e9).
 * @return x^y mod `mod`, always in [0, mod).
 */
long long mod_pow(long long x,long long y,long long mod){
    x%=mod;
    if(x<0)x+=mod;
    long long ret=1%mod;  // 1%mod so that mod == 1 yields 0, not 1
    while(y>0){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return ret;
}
#line 2 "util/random_gen.hpp"
#include<random>
#include<chrono>

/**
 * Small convenience wrapper around std::mt19937, seeded from the current
 * std::chrono::steady_clock tick count. Each call draws a uniformly
 * distributed int.
 */
struct RandomNumberGenerator {
    std::mt19937 mt;

    RandomNumberGenerator()
        : mt(static_cast<std::mt19937::result_type>(
              std::chrono::steady_clock::now().time_since_epoch().count())) {}

    /// Uniform int in the half-open range [a, b); requires a < b.
    int operator()(int a, int b) {
        return std::uniform_int_distribution<int>{a, b - 1}(mt);
    }

    /// Uniform int in [0, b); requires b > 0.
    int operator()(int b) {
        return operator()(0, b);
    }
};
#line 4 "math/mod_sqrt.hpp"

/**
 * @brief ModSqrt
 *
 * Finds a square root of a modulo a prime `mod` (Tonelli-Shanks algorithm).
 *
 * @param a   any long long; it is reduced into [0, mod) first, so negative
 *            values and multiples of mod are accepted.
 * @param mod a prime. Intermediate products are computed in long long, so
 *            (mod-1)^2 must fit in long long (roughly mod <= 3.03e9).
 * @return some r in [0, mod) with r*r % mod == a % mod, or -1 if a is a
 *         quadratic non-residue modulo mod. Either of the two roots may be
 *         returned.
 */
long long mod_sqrt(long long a,long long mod){
    // Normalize first: handles negative a and nonzero multiples of mod (root 0).
    a%=mod;
    if(a<0)a+=mod;
    if(mod==2||a==0)return a;
    // Euler's criterion: a is a quadratic residue iff a^((mod-1)/2) == 1.
    if(mod_pow(a,(mod-1)/2,mod)!=1)return -1;
    // mod == 3 (mod 4): a^((mod+1)/4) is a root directly.
    if(mod%4==3)return mod_pow(a,(mod+1)/4,mod);
    // Write mod-1 = q * 2^s with q odd.
    long long q=mod-1,s=0;
    while(q%2==0)q/=2,++s;
    // Smallest quadratic non-residue z. A deterministic scan keeps everything in
    // long long (no narrowing of mod to int) and makes the result reproducible.
    long long z=2;
    while(z<mod&&mod_pow(z,(mod-1)/2,mod)!=mod-1)++z;
    if(z==mod)return -1;  // unreachable for prime mod; guards against misuse
    // Invariants kept by the loop: c^(2^(m-1)) == -1, t^(2^(m-1)) == 1, r^2 == a*t.
    long long c=mod_pow(z,q,mod),t=mod_pow(a,q,mod),r=mod_pow(a,(q+1)/2,mod),m=s;
    while(m>1){
        // t^(2^(m-2)) is +1 or -1; when it is -1, multiplying t by c^2
        // (and r by c) turns it into +1 while preserving r^2 == a*t.
        if(mod_pow(t,1LL<<(m-2),mod)!=1){
            t=t*c%mod*c%mod;
            r=r*c%mod;
        }
        c=c*c%mod;
        --m;
    }
    // With m == 1 the invariant forces t == 1, hence r^2 == a.
    return r;
}
Back to top page