Move all of the forward pass to under torch.no_grad() (#3636)

This commit is contained in:
Sebastian 2022-11-28 23:59:49 -08:00 committed by GitHub
parent b20f808119
commit c7c2235874
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -493,9 +493,9 @@ class _TapasScoredEncoder(_BaseTapasEncoder):
# Forward pass through model
with torch.no_grad():
outputs = self.model.tapas(**inputs)
table_score = self.model.classifier(outputs.pooler_output)
# Get general table score
table_score = self.model.classifier(outputs.pooler_output)
table_score_softmax = torch.nn.functional.softmax(table_score, dim=1)
table_relevancy_prob = table_score_softmax[0][1].item()
no_answer_score = table_score_softmax[0][0].item()