D
D
dBegginer, 2020-01-21 09:48:18
Python
dBegginer, 2020-01-21 09:48:18

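How do I run a model with DataParallel in PyTorch? The forward pass fails with "arguments are located on different GPUs" (see the traceback below).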

# Construct the GPT-2 style model from the configured hyper-parameters.
net = dgpt.GTP2(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    vocab_size=VOCAB_SIZE,
    pe_target=CONTEXT_SIZE,
    rate=0.1,
)
# Wrap for data-parallel execution on GPUs 0 and 1, then move the wrapper
# (and with it the wrapped model) to the current CUDA device.
net = torch.nn.DataParallel(net, device_ids=[0, 1]).cuda()

# DataParallel splits this batch and scatters the pieces across device_ids.
inputs = inputs.cuda()
outputs = net(inputs)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        self.depth = d_model // self.num_heads

        self.wq = nn.Linear(d_model, d_model).cuda()
        self.wk = nn.Linear(d_model, d_model).cuda()
        self.wv = nn.Linear(d_model, d_model).cuda()

        self.dense = nn.Linear(d_model, d_model).cuda()

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask):
        batch_size = q.shape[0]

        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

File "/home/m/Документы/mgpu/dgpt2.py", line 117, in forward
    q = self.wq(q)
RuntimeError: arguments are located on different GPUs at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:277

Answer the question

In order to leave comments, you need to log in

Didn't find what you were looking for?

Ask your question

Ask a Question

731 491 924 answers to any question