minor bug fixes (#207)

* fixed path arg for create_dataset_csvs()

* updated assign_check() to remove user warning
This commit is contained in:
Daniel Kleine 2024-06-12 13:27:56 +02:00 committed by GitHub
parent b2ff989174
commit e5c3c5ce99
2 changed files with 5 additions and 14 deletions

View File

@ -172,7 +172,7 @@
"def assign_check(left, right):\n",
" if left.shape != right.shape:\n",
" raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n",
" return torch.nn.Parameter(torch.tensor(right))"
" return torch.nn.Parameter(right.clone().detach())"
]
},
{
@ -227,16 +227,7 @@
"execution_count": 7,
"id": "cda44d37-92c0-4c19-a70a-15711513afce",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_9385/3877979348.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" return torch.nn.Parameter(torch.tensor(right))\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"from previous_chapters import GPTModel\n",
@ -250,7 +241,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "4ddd0d51-3ade-4890-9bab-d63f141d095f",
"metadata": {},
"outputs": [
@ -302,7 +293,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.11"
}
},
"nbformat": 4,

View File

@ -117,7 +117,7 @@ def random_split(df, train_frac, validation_frac):
return train_df, validation_df, test_df
def create_dataset_csvs(data_file_path):
def create_dataset_csvs(new_file_path):
df = pd.read_csv(new_file_path, sep="\t", header=None, names=["Label", "Text"])
# Create balanced dataset