Raise error if torch-scatter is not installed or wrong version is installed (#2486)

* automatically download correct torch-scatter version

* raise error if torch-scatter is not installed

* Update Documentation & Code Style

* catch all import errors and fix linter

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
MichelBartels 2022-05-05 10:12:10 +02:00 committed by GitHub
parent 1418f0c603
commit 5d98810a17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 2 deletions

View File

@ -36,7 +36,10 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab]
# The TaPAs-based TableReader requires the torch-scatter library
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
import torch
version = torch.__version__
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{version}.html
# Install pygraphviz for visualization of Pipelines
!apt install libgraphviz-dev

View File

@ -21,6 +21,15 @@ from haystack.schema import Document, Answer, Span
from haystack.nodes.reader.base import BaseReader
from haystack.modeling.utils import initialize_device_settings
torch_scatter_installed = True
torch_scatter_wrong_version = False
try:
import torch_scatter # pylint: disable=unused-import
except ImportError:
torch_scatter_installed = False
except OSError:
torch_scatter_wrong_version = True
logger = logging.getLogger(__name__)
@ -95,6 +104,15 @@ class TableReader(BaseReader):
query + table exceed max_seq_len, the table will be truncated by removing rows until the
input size fits the model.
"""
if not torch_scatter_installed:
raise ImportError(
"Please install torch_scatter to use TableReader. You can follow the instructions here: https://github.com/rusty1s/pytorch_scatter"
)
if torch_scatter_wrong_version:
raise ImportError(
"torch_scatter could not be loaded. This could be caused by a mismatch between your cuda version and the one used by torch_scatter."
"Please try to reinstall torch-scatter. You can follow the instructions here: https://github.com/rusty1s/pytorch_scatter"
)
super().__init__()
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

View File

@ -55,7 +55,10 @@
"!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab]\n",
"\n",
"# The TaPAs-based TableReader requires the torch-scatter library\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
"import torch\n",
"\n",
"version = torch.__version__\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{version}.html\n",
"\n",
"# Install pygraphviz for visualization of Pipelines\n",
"!apt install libgraphviz-dev\n",