Use PyTorch as logits transpose for ONNX support (#141)

This commit is contained in:
Michael Goin 2022-09-26 13:54:26 -04:00 committed by GitHub
parent 2037b65f3f
commit 9c8183a179
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -189,7 +189,7 @@ class TextDecoder(nn.Module):
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
return logits