delete use_half

shitao 2023-08-23 13:10:56 +08:00
parent d224e686ad
commit c032c33204
5 changed files with 6 additions and 8 deletions

View File

@@ -37,6 +37,7 @@ torchrun --nproc_per_node {number of gpus} \
 --output_dir {path to save model} \
 --model_name_or_path {base model} \
 --train_data {path to train data} \
+--per_device_train_batch_size {batch size} \
 --learning_rate 2e-5 \
 --num_train_epochs 5 \
 --max_seq_length 512
@@ -66,7 +67,7 @@ torchrun --nproc_per_node {number of gpus} \
 --train_data toy_finetune_data.jsonl \
 --learning_rate 1e-5 \
 --num_train_epochs 5 \
---per_device_train_batch_size {large batch size} \
+--per_device_train_batch_size {batch size} \
 --dataloader_drop_last True \
 --normlized True \
 --temperature 0.01 \

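Note on the flag added above: --per_device_train_batch_size is a per-GPU value, so under torchrun the effective batch size also scales with --nproc_per_node (and with gradient accumulation, which the Hugging Face Trainer defaults to 1). A minimal Python sketch with assumed example numbers, not values from this commit:

# Illustrative numbers only; the variable names mirror the command-line arguments above.
nproc_per_node = 8                   # GPUs passed to torchrun
per_device_train_batch_size = 32     # value given to the training script
gradient_accumulation_steps = 1      # Trainer default when the flag is not set
effective_batch_size = (nproc_per_node
                        * per_device_train_batch_size
                        * gradient_accumulation_steps)
print(effective_batch_size)          # 256 samples per optimizer step
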
View File

@@ -11,8 +11,7 @@ class FlagModel:
         model_name_or_path: str = None,
         pooling_method: str = 'cls',
         normalize_embeddings: bool = True,
-        query_instruction_for_retrieval: str = None,
-        use_half: bool = True,
+        query_instruction_for_retrieval: str = None
     ) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
@@ -21,8 +20,6 @@ class FlagModel:
         self.normalize_embeddings = normalize_embeddings
         self.pooling_method = pooling_method
-        if use_half: self.model.half()
         self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         self.model = self.model.to(self.device)

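With use_half gone from FlagModel.__init__, half precision is no longer applied automatically. A minimal sketch of how a caller can still opt into fp16 after construction; the import path and checkpoint name below are illustrative assumptions, not part of this commit:

import torch
from FlagEmbedding import FlagModel  # package-level import path assumed

# Construct the model with the keyword arguments that remain after this commit.
model = FlagModel(
    "BAAI/bge-large-zh",                 # example checkpoint name
    pooling_method='cls',
    normalize_embeddings=True,
    query_instruction_for_retrieval=None,
)

# Replacement for the deleted "if use_half: self.model.half()" line,
# now done explicitly by the caller; fp16 only pays off on CUDA devices.
if torch.cuda.is_available():
    model.model.half()
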
View File

@@ -41,7 +41,7 @@ torchrun --nproc_per_node {number of gpus} \
 --train_data toy_finetune_data.jsonl \
 --learning_rate 1e-5 \
 --num_train_epochs 5 \
---per_device_train_batch_size {large batch size} \
+--per_device_train_batch_size {large batch size; set 1 for toy data} \
 --dataloader_drop_last True \
 --normlized True \
 --temperature 0.01 \

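The "set 1 for toy data" hint above interacts with --dataloader_drop_last True: incomplete batches are dropped, so if the per-device batch size exceeds a rank's share of toy_finetune_data.jsonl, every batch is dropped and the epoch runs zero steps. A minimal sketch of that arithmetic with assumed example numbers:

# Hypothetical sizes; the real length of toy_finetune_data.jsonl is not stated in the diff.
num_examples = 10                    # toy dataset size (assumed)
nproc_per_node = 2                   # GPU count (assumed)
per_device_train_batch_size = 32     # a "large batch size"

examples_per_rank = num_examples // nproc_per_node                   # 5
steps_per_epoch = examples_per_rank // per_device_train_batch_size   # drop_last keeps only full batches
print(steps_per_epoch)               # 0 -> nothing is trained; hence batch size 1 for toy data
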
View File

@@ -37,7 +37,7 @@ torchrun --nproc_per_node {number of gpus} \
 --train_data toy_pretrain_data.jsonl \
 --learning_rate 2e-5 \
 --num_train_epochs 2 \
---per_device_train_batch_size {batch size} \
+--per_device_train_batch_size {batch size; set 1 for toy data} \
 --dataloader_drop_last True \
 --max_seq_length 512 \
 --logging_steps 10

View File

@@ -5,7 +5,7 @@ with open("README.md", mode="r", encoding="utf-8") as readme_file:
 setup(
     name='FlagEmbedding',
-    version='1.0.3',
+    version='1.0.4',
     description='FlagEmbedding',
     long_description=readme,
     long_description_content_type="text/markdown",