mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-12 08:03:50 +00:00
Increase FARM to Version 0.6.2 (#755)
* Increase FARM version
* Fix test
This commit is contained in:
parent
725c03220f
commit
7522d2d1b0
@ -508,7 +508,7 @@ class FARMReader(BaseReader):
|
|||||||
# Create DataLoader that can be passed to the Evaluator
|
# Create DataLoader that can be passed to the Evaluator
|
||||||
tic = perf_counter()
|
tic = perf_counter()
|
||||||
indices = range(len(farm_input))
|
indices = range(len(farm_input))
|
||||||
dataset, tensor_names = self.inferencer.processor.dataset_from_dicts(farm_input, indices=indices)
|
dataset, tensor_names, problematic_ids = self.inferencer.processor.dataset_from_dicts(farm_input, indices=indices)
|
||||||
data_loader = NamedDataLoader(dataset=dataset, batch_size=self.inferencer.batch_size, tensor_names=tensor_names)
|
data_loader = NamedDataLoader(dataset=dataset, batch_size=self.inferencer.batch_size, tensor_names=tensor_names)
|
||||||
|
|
||||||
evaluator = Evaluator(data_loader=data_loader, tasks=self.inferencer.processor.tasks, device=device)
|
evaluator = Evaluator(data_loader=data_loader, tasks=self.inferencer.processor.tasks, device=device)
|
||||||
|
|||||||
@ -179,7 +179,7 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
dataset, tensor_names, baskets = self.processor.dataset_from_dicts(
|
dataset, tensor_names, problematic_ids, baskets = self.processor.dataset_from_dicts(
|
||||||
dicts, indices=[i for i in range(len(dicts))], return_baskets=True
|
dicts, indices=[i for i in range(len(dicts))], return_baskets=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,8 @@ import logging
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from haystack import Document
|
from haystack import Document
|
||||||
from haystack.summarizer.base import BaseSummarizer
|
from haystack.summarizer.base import BaseSummarizer
|
||||||
@ -73,7 +75,11 @@ class TransformersSummarizer(BaseSummarizer):
|
|||||||
into a single text. This separator appears between those subsequent docs.
|
into a single text. This separator appears between those subsequent docs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.summarizer = pipeline("summarization", model=model_name_or_path, tokenizer=tokenizer, device=use_gpu)
|
# TODO: AutoModelForSeq2SeqLM is only necessary with transformers==4.1.1; with newer versions, use the pipeline directly.
|
||||||
|
if tokenizer is None:
|
||||||
|
tokenizer = model_name_or_path
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_name_or_path)
|
||||||
|
self.summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=use_gpu)
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
|
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
farm==0.5.0
|
farm==0.6.2
|
||||||
--find-links=https://download.pytorch.org/whl/torch_stable.html
|
--find-links=https://download.pytorch.org/whl/torch_stable.html
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
|
|||||||
@ -113,6 +113,6 @@ def test_dpr_saving_and_loading(retriever, document_store):
|
|||||||
assert loaded_retriever.query_tokenizer.do_lower_case == True
|
assert loaded_retriever.query_tokenizer.do_lower_case == True
|
||||||
assert loaded_retriever.passage_tokenizer.vocab_size == 30522
|
assert loaded_retriever.passage_tokenizer.vocab_size == 30522
|
||||||
assert loaded_retriever.query_tokenizer.vocab_size == 30522
|
assert loaded_retriever.query_tokenizer.vocab_size == 30522
|
||||||
assert loaded_retriever.passage_tokenizer.max_len == 512
|
assert loaded_retriever.passage_tokenizer.model_max_length == 512
|
||||||
assert loaded_retriever.query_tokenizer.max_len == 512
|
assert loaded_retriever.query_tokenizer.model_max_length == 512
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user