mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 15:38:36 +00:00
Raise error if torch-scatter is not installed or wrong version is installed (#2486)
* automatically download correct torch-scatter version * raise error if torch-scatter is not installed * Update Documentation & Code Style * catch all import errors and fix linter * Update Documentation & Code Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
1418f0c603
commit
5d98810a17
@ -36,7 +36,10 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial
|
||||
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab]
|
||||
|
||||
# The TaPAs-based TableReader requires the torch-scatter library
|
||||
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
|
||||
import torch
|
||||
|
||||
version = torch.__version__
|
||||
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{version}.html
|
||||
|
||||
# Install pygraphviz for visualization of Pipelines
|
||||
!apt install libgraphviz-dev
|
||||
|
||||
@ -21,6 +21,15 @@ from haystack.schema import Document, Answer, Span
|
||||
from haystack.nodes.reader.base import BaseReader
|
||||
from haystack.modeling.utils import initialize_device_settings
|
||||
|
||||
torch_scatter_installed = True
|
||||
torch_scatter_wrong_version = False
|
||||
try:
|
||||
import torch_scatter # pylint: disable=unused-import
|
||||
except ImportError:
|
||||
torch_scatter_installed = False
|
||||
except OSError:
|
||||
torch_scatter_wrong_version = True
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -95,6 +104,15 @@ class TableReader(BaseReader):
|
||||
query + table exceed max_seq_len, the table will be truncated by removing rows until the
|
||||
input size fits the model.
|
||||
"""
|
||||
if not torch_scatter_installed:
|
||||
raise ImportError(
|
||||
"Please install torch_scatter to use TableReader. You can follow the instructions here: https://github.com/rusty1s/pytorch_scatter"
|
||||
)
|
||||
if torch_scatter_wrong_version:
|
||||
raise ImportError(
|
||||
"torch_scatter could not be loaded. This could be caused by a mismatch between your cuda version and the one used by torch_scatter."
|
||||
"Please try to reinstall torch-scatter. You can follow the instructions here: https://github.com/rusty1s/pytorch_scatter"
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
|
||||
|
||||
@ -55,7 +55,10 @@
|
||||
"!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab]\n",
|
||||
"\n",
|
||||
"# The TaPAs-based TableReader requires the torch-scatter library\n",
|
||||
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"version = torch.__version__\n",
|
||||
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{version}.html\n",
|
||||
"\n",
|
||||
"# Install pygraphviz for visualization of Pipelines\n",
|
||||
"!apt install libgraphviz-dev\n",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user