Python
Yura Khlyan, 2015-07-31 14:49:40

Optimal modular exponentiation algorithm: how can it be improved?

Good day.
Here is the task:
compute the last 12 digits of the expression A * B**C + D (where ** denotes exponentiation),
Where, 1 <= A, B, C, D <= 10**9
Here is my implementation in Python:

n = int(input()) #number of test cases
module = 10 ** 12

def pow_l(x, n):
    result = 1
    while n != 0:
        if n % 2 != 0:
            result *= x
            result %= module
            n -= 1
        else:
            x *= x
            x %= module
            n //= 2  # integer division; n /= 2 would turn n into a float in Python 3
    return result

for i in range(n):
    a, b, c, d = (int(num) for num in input().split())
    prod = (a * pow_l(b, c) + d) % module
    print(str(prod).zfill(12))  # pad with leading zeros to exactly 12 digits

But, alas, this algorithm is too slow. What would you advise?
I would be very grateful for your help.
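A quick check of the reasoning behind the task: the last 12 digits of a number are its residue modulo 10**12, left-padded with zeros, and because reduction modulo M is compatible with + and *, B**C can be reduced before it is multiplied by A. The values below are arbitrary and small enough to compute the exact result for comparison:

```python
M = 10 ** 12  # the last 12 decimal digits of x are x % M, left-padded with zeros

a, b, c, d = 123, 456, 20, 789  # arbitrary small values, so the exact result is cheap
exact = a * b ** c + d          # exact big-integer value of A * B**C + D

# Reducing B**C modulo M first leaves the final residue unchanged.
reduced = (a * (b ** c % M) + d) % M
assert reduced == exact % M

# zfill keeps any leading zeros among the last 12 digits.
last12 = str(reduced).zfill(12)
assert last12 == str(exact)[-12:]
print(last12)
```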

3 answers
Vitaly, 2015-07-31
@denshi

Python already has a built-in function for modular exponentiation: pow(b, c, module). You are unlikely to write a faster one in pure Python.
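A minimal sketch of the question's per-test computation rewritten around the built-in three-argument pow; the function name solve is hypothetical:

```python
MODULE = 10 ** 12  # the last 12 digits


def solve(a: int, b: int, c: int, d: int) -> str:
    """Last 12 digits of a * b**c + d, zero-padded to 12 characters."""
    # Three-argument pow does modular exponentiation in C and never
    # builds the full (enormous) value of b**c.
    return str((a * pow(b, c, MODULE) + d) % MODULE).zfill(12)


print(solve(10**9, 10**9, 10**9, 10**9))  # → 001000000000
```

In the original loop this amounts to replacing pow_l(b, c) with pow(b, c, module).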

Vladimir Martyanov, 2015-07-31
@vilgeforce

See, for example, the implementation in OpenSSL. You can also look at MIRACL; its sources are open too.

⚡ Kotobotov ⚡, 2015-07-31
@angrySCV

You already have a fast implementation. All that remains is to optimize the expensive operations (conditional branching is one of the most expensive) and, where possible, get rid of unnecessary ones.
For example, instead of checking whether the number is divisible by 2, it is enough to check whether its last bit is zero.
Division by 2 can also be replaced with a shift.
Better still, get rid of conditional operations and replace them with unconditional ones: the x * x operation occurs on every pass, so that part can be merged,
and the remaining update of n can be written as a single expression, where the remainder bit (the last bit of n) is either 1 or 0 depending on the parity of n.
Example: bit = n & 1; n = (n >> 1) * (1 - bit) + (n - 1) * bit
Such optimization is guaranteed to speed up the work.
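A sketch of the usual right-to-left binary exponentiation in the spirit of this answer (the names pow_bits and pow_branchless are hypothetical). The parity test becomes n & 1, halving becomes n >>= 1, and squaring runs unconditionally on every pass, so the separate n -= 1 iteration disappears and the loop runs once per bit of n. The second function also removes the if by turning the bit into an arithmetic multiplier. In CPython, per-iteration interpreter overhead is likely to dominate, so it is worth timing either version against the built-in pow(x, n, m).

```python
MODULE = 10 ** 12  # keep the last 12 digits


def pow_bits(x: int, n: int, m: int = MODULE) -> int:
    """x**n % m by right-to-left binary exponentiation with bit tests and shifts."""
    result = 1
    x %= m
    while n:
        if n & 1:                  # lowest bit set: this power of x contributes
            result = result * x % m
        x = x * x % m              # squaring happens on every pass
        n >>= 1                    # shift instead of division by 2
    return result % m              # % m also covers n == 0 with m == 1


def pow_branchless(x: int, n: int, m: int = MODULE) -> int:
    """Same loop, but the multiply step uses arithmetic instead of an if."""
    result = 1
    x %= m
    while n:
        bit = n & 1
        # Multiplier is x when bit == 1 and 1 when bit == 0.
        result = result * (1 + (x - 1) * bit) % m
        x = x * x % m
        n >>= 1
    return result % m


# Both agree with the built-in three-argument pow.
assert pow_bits(10**9, 10**9) == pow(10**9, 10**9, MODULE)
assert pow_branchless(123456789, 987654321) == pow(123456789, 987654321, MODULE)
```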
