Python
keddad, 2019-07-25 23:51:07

How to speed up this Python code as much as possible?

I have the following code that solves this problem:

from array import array

n, m = map(int, input().split())

parent, weight, rank = (array('i', [-1 for _ in range(n)]),
                        array('I', [0 for _ in range(n)]),
                        array('I', [1 for _ in range(n)]))


def find_set(v: int) -> int:
    if parent[v] == -1:
        parent[v] = v
        return v
    if v == parent[v]:
        return v
    parent[v] = find_set(parent[v])
    weight[parent[v]] += weight[v]
    weight[v] = 0
    return parent[v]


def union_sets(a: int, b: int, cost: int) -> None:
    a = find_set(a)
    b = find_set(b)
    if a != b:
        if rank[a] < rank[b]:
            a, b = b, a
        parent[b] = a
        weight[a] += cost
        weight[a] += weight[b]
        weight[b] = 0
        if rank[a] == rank[b]:
            rank[a] += 1
    else:
        weight[a] += cost


with open("input.txt", "r") as inp:
    with open("output.txt", "w") as out:
        inp.__next__()
        for line in inp:
            st = line.split()
            if len(st) != 4:
                out.write(str(weight[find_set(int(st[1]) - 1)]) + "\n")
            else:
                union_sets(int(st[1]) - 1, int(st[2]) - 1, int(st[3]))

Unfortunately, it slightly exceeds the time limit. In theory, one could just rewrite the same algorithm in C++, but we are not looking for easy ways! What other optimizations can be applied to the code above to reduce its running time on large inputs?


3 answer(s)
keddad, 2019-07-26
@keddad

Okay, in the final version I got rid of recursion, type annotations and arrays. Most of the gain, of course, came from removing recursion. It still wasn't enough to pass, but the code became noticeably faster.

n, m = map(int, input().split())

parent, weight, rank = [-1 for _ in range(n)], [0 for _ in range(n)], [1 for _ in range(n)]


def find_set(v):
    while parent[v] != -1 and parent[v] != v:
        parent[v] = parent[parent[v]]
        weight[parent[v]] += weight[v]
        weight[v] = 0
        v = parent[v]
    if parent[v] == -1:
        parent[v] = v
    return v


def union_sets(a, b, cost):
    a = find_set(a)
    b = find_set(b)
    if a != b:
        if rank[a] < rank[b]:
            a, b = b, a
        parent[b] = a
        weight[a] += cost
        weight[a] += weight[b]
        weight[b] = 0
        if rank[a] == rank[b]:
            rank[a] += 1
    else:
        weight[a] += cost


def main():
    with open("input.txt", "r") as inp:
        with open("output.txt", "w") as out:
            inp.__next__()
            for line in inp:
                st = line.split()
                if len(st) != 4:
                    out.write(str(weight[find_set(int(st[1]) - 1)]) + "\n")
                else:
                    union_sets(int(st[1]) - 1, int(st[2]) - 1, int(st[3]))


main()
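To try the final version without input.txt and output.txt, the same logic can be wrapped in a function that takes the lines directly and returns the answers. This is a sketch under assumptions: the wrapper name `solve` and the sample lines are hypothetical, and the input format is inferred from the code alone (a line with four tokens is a union with a cost, any other line is a query on the vertex in its second token; the first token is never read). The two `weight[a] +=` lines are merged into one, which does not change the result.

```python
def solve(n, lines):
    """Hypothetical wrapper around the final version's union-find.

    Reads query lines from `lines` and returns the answers as a list
    instead of using input.txt/output.txt.
    """
    parent = [-1] * n  # -1: untouched singleton; parent[v] == v: root
    weight = [0] * n   # total cost of a component, kept only at its root
    rank = [1] * n

    def find_set(v):
        # Iterative path halving instead of recursion.
        while parent[v] != -1 and parent[v] != v:
            parent[v] = parent[parent[v]]
            weight[parent[v]] += weight[v]
            weight[v] = 0
            v = parent[v]
        if parent[v] == -1:
            parent[v] = v
        return v

    def union_sets(a, b, cost):
        a = find_set(a)
        b = find_set(b)
        if a != b:
            if rank[a] < rank[b]:
                a, b = b, a
            parent[b] = a
            weight[a] += cost + weight[b]
            weight[b] = 0
            if rank[a] == rank[b]:
                rank[a] += 1
        else:
            weight[a] += cost

    answers = []
    for line in lines:
        st = line.split()
        if len(st) != 4:
            answers.append(weight[find_set(int(st[1]) - 1)])
        else:
            union_sets(int(st[1]) - 1, int(st[2]) - 1, int(st[3]))
    return answers


# Hypothetical sample: 4 tokens = union with a cost, 2 tokens = query.
sample = ["1 1 2 5", "2 1", "1 2 3 7", "2 3", "1 1 3 4", "2 2"]
print(solve(3, sample))  # [5, 12, 16]
```

Keeping `parent`, `weight` and `rank` inside the function means each call starts from a fresh forest, which makes the logic easy to exercise repeatedly on small inputs.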

Roman Kitaev, 2019-07-26
@deliro

1. The [-1 for _ in range(n)] construct already creates a list. That list is then thrown away once array.array has copied it. In total, 6 potentially huge collections are created in a single statement. You can either change it to (-1 for _ in range(n)) or abandon array.array altogether; its benefits here are questionable:

In [3]: a = array("I", range(10000))

In [4]: b = list(range(10000))

In [5]: %timeit sum(a)
206 µs ± 6.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [6]: %timeit sum(b)
69.3 µs ± 367 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [7]: %timeit a[7777]
49.5 ns ± 0.564 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

In [8]: %timeit b[7777]
33.6 ns ± 0.411 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

2. Remove recursion entirely.
3. Remove the type annotations.
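The allocation options from point 1 can be compared side by side. This is a minimal sketch with a small hypothetical n. The `array('i', [-1]) * n` form, which relies on the sequence repetition that array.array supports, is an extra option not mentioned in the answer. All four variants hold the same values; they differ only in whether a throwaway intermediate list is built. No speed claims are made here.

```python
from array import array

n = 5  # hypothetical size; in the question it comes from the input

# What the question does: build a full list, then copy it into an array.
a1 = array('i', [-1 for _ in range(n)])

# Option from the answer: pass a generator, so no intermediate list is kept.
a2 = array('i', (-1 for _ in range(n)))

# Extra option (not from the answer): repeat a one-element array n times.
a3 = array('i', [-1]) * n

# Or drop array.array and use a plain list (safe here because ints are immutable).
lst = [-1] * n

assert a1 == a2 == a3
assert a1.tolist() == lst
```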

Vladimir Olohtonov, 2019-07-26
@sgjurano

Are you sure your paths are being compressed correctly?
Also, it's better to get rid of recursion: function calls in Python are quite expensive.
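One way to answer the path-compression question empirically is to compare the union-find against a brute-force recomputation on random operations. This is a minimal sketch under the assumption that the operations are "add an edge with a cost" and "ask for the total cost of the component containing v", which is what the question's code appears to compute. `make_dsu`, `brute_force` and `cross_check` are hypothetical helper names. The union-find part is the question's recursive version, wrapped in a closure so that it can be instantiated repeatedly.

```python
import random


def make_dsu(n):
    """The question's recursive scheme, wrapped in a closure for testing."""
    parent, weight, rank = [-1] * n, [0] * n, [1] * n

    def find_set(v):
        if parent[v] == -1:
            parent[v] = v
            return v
        if v == parent[v]:
            return v
        parent[v] = find_set(parent[v])
        weight[parent[v]] += weight[v]
        weight[v] = 0
        return parent[v]

    def union_sets(a, b, cost):
        a, b = find_set(a), find_set(b)
        if a != b:
            if rank[a] < rank[b]:
                a, b = b, a
            parent[b] = a
            weight[a] += cost + weight[b]
            weight[b] = 0
            if rank[a] == rank[b]:
                rank[a] += 1
        else:
            weight[a] += cost

    def query(v):
        return weight[find_set(v)]

    return union_sets, query


def brute_force(n, edges, v):
    """Sum of the costs of all edges in v's component, found by a plain DFS."""
    adj = [[] for _ in range(n)]
    for a, b, _ in edges:
        adj[a].append(b)
        adj[b].append(a)
    seen, stack = {v}, [v]
    while stack:
        u = stack.pop()
        for w in adj[u]:
            if w not in seen:
                seen.add(w)
                stack.append(w)
    return sum(c for a, _, c in edges if a in seen)


def cross_check(n=30, steps=1000, seed=1):
    """Random unions and queries; raises AssertionError on the first mismatch."""
    rng = random.Random(seed)
    union_sets, query = make_dsu(n)
    edges = []
    for _ in range(steps):
        if rng.random() < 0.5:
            a, b, c = rng.randrange(n), rng.randrange(n), rng.randrange(1, 100)
            union_sets(a, b, c)
            edges.append((a, b, c))
        else:
            v = rng.randrange(n)
            assert query(v) == brute_force(n, edges, v), (v, len(edges))
    return True


print(cross_check())
```

If every query matches, `cross_check` returns `True`; otherwise the assertion reports the vertex and the number of edges added so far, which gives a small reproducible case to inspect.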
