unstructured/examples/sec-sentiment-analysis/sec-sentiment-analysis.ipynb
Matt Robinson d9c035edb1
docs: no more bricks (#1967)
### Summary

We no longer use the "bricks" terminology for partitioning functions, etc.
in the library. This PR updates various references to bricks within the
repo and the docs. This is just an initial pass to swap the terminology
out; it'll likely be helpful to reorganize the docs a bit as well.

---------

Co-authored-by: qued <64741807+qued@users.noreply.github.com>
Co-authored-by: ryannikolaidis <1208590+ryannikolaidis@users.noreply.github.com>
2023-11-02 09:43:26 -05:00

{
"cells": [
{
"cell_type": "markdown",
"id": "b6d3b12f",
"metadata": {},
"source": [
"## SEC Sentiment Analysis\n",
"\n",
"In this notebook, we'll show how to use the [`unstructured` core library](https://unstructured-io.github.io/unstructured/) and the [SEC pipelines API](https://github.com/Unstructured-IO/pipeline-sec-filings) to train a sentiment analysis model using content from the risk factors section of 10-K filings. To train and use the sentiment analysis model, we'll perform the following steps:\n",
"\n",
"1. [Grab 10-K filings from EDGAR](#get-filings)\n",
"1. [Extract the Risk Factors section using the SEC pipelines API](#extract-narrative)\n",
"1. [Use a staging function to stage the data for a labeling task in LabelStudio](#stage-label-studio)\n",
"1. [Train a sentiment analysis model with Hugging Face](#train)\n",
"1. [Use a staging function to chunk input for the attention window of the sentiment analysis model](#chunk)"
]
},
{
"cell_type": "markdown",
"id": "9dce7f11",
"metadata": {},
"source": [
"### Grab 10-K filings from EDGAR <a id=\"get-filings\"></a>\n",
"\n",
"The first step in the process is to pull documents from EDGAR, the SEC's filing system. Filings in EDGAR are in XML format and use a standard called [XBRL](https://www.xbrl.org/the-standard/what/ixbrl/). To pull the filings, we'll make a few API calls based on the ticker symbols of publicly traded companies and save the files locally in a directory called `xbrl-forms`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8534e135",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from fetch import get_form_by_ticker"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "820c49a1",
"metadata": {},
"outputs": [],
"source": [
"tickers = [\n",
" \"ehc\",\n",
" \"mrk\",\n",
" \"nke\",\n",
" \"msex\",\n",
" \"v\",\n",
" \"cvs\",\n",
" \"doc\",\n",
" \"smtc\",\n",
" \"cl\",\n",
" \"ava\",\n",
" \"bc\",\n",
" \"f\",\n",
" \"lmt\",\n",
" \"cri\",\n",
" \"aig\",\n",
" \"rgld\",\n",
" \"apld\",\n",
" \"omcl\",\n",
" \"mmm\",\n",
" \"bgs\",\n",
" \"dis\",\n",
" \"wetg\",\n",
" \"bj\",\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f50b5e40",
"metadata": {},
"outputs": [],
"source": [
"cwd = os.getcwd()\n",
"data_directory = os.path.join(cwd, \"xbrl-forms\")\n",
"# Create the output directory if it does not exist yet.\n",
"os.makedirs(data_directory, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"id": "b0b8e632",
"metadata": {},
"source": [
"If you're running this notebook at home, make sure to update the company and email as appropriate to correctly identify yourself to the SEC API."
]
},
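{
"cell_type": "markdown",
"id": "4f2a9c10",
"metadata": {},
"source": [
"The SEC asks automated clients to identify themselves, which is why `get_form_by_ticker` takes `company` and `email` arguments. As a minimal sketch, assuming those two values end up in a `User-Agent` request header, the identification might look like the cell below. The `build_sec_headers` function is a hypothetical name used for illustration only; it is not part of `fetch.py`, which may build its headers differently."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f2a9c11",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper for illustration only; fetch.py may build its headers differently.\n",
"def build_sec_headers(company: str, email: str) -> dict:\n",
"    \"\"\"Return request headers that identify the caller to the SEC.\"\"\"\n",
"    return {\"User-Agent\": f\"{company} {email}\"}\n",
"\n",
"\n",
"headers = build_sec_headers(\"Unstructured Technologies\", \"support@unstructured.io\")\n",
"print(headers[\"User-Agent\"])"
]
},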
{
"cell_type": "code",
"execution_count": null,
"id": "9b9d7c32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"......................."
]
}
],
"source": [
"forms = []\n",
"for ticker in tickers:\n",
"    form_text = get_form_by_ticker(\n",
"        ticker=ticker,\n",
"        form_type=\"10-K\",\n",
"        company=\"Unstructured Technologies\",\n",
"        email=\"support@unstructured.io\",\n",
"    )\n",
"\n",
"    filename = os.path.join(data_directory, f\"{ticker}-10k.xbrl\")\n",
"    with open(filename, \"w\") as f:\n",
"        f.write(form_text)\n",
"    print(\".\", end=\"\")"
]
},
{
"cell_type": "markdown",
"id": "ce79d4e7",
"metadata": {},
"source": [
"### Extract the Risk Factors Narrative <a id=\"extract-narrative\"></a>\n",
"\n",
"Next, we'll extract the risk factors section by submitting the documents to the Unstructured SEC pipelines API. The SEC pipelines API accepts documents in XBRL format, finds the requested section, and returns that section as JSON. You can learn more about the SEC pipelines API in the [pipeline-sec-filings repository](https://github.com/Unstructured-IO/pipeline-sec-filings)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7727c8e",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"import time\n",
"from fetch import get_version"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc99e2be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"https://api.unstructured.io/sec-filings/v0.2.0/section\n"
]
}
],
"source": [
"version = get_version()\n",
"url = f\"https://api.unstructured.io/sec-filings/v{version}/section\"\n",
"print(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebd5f5bc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"......................."
]
}
],
"source": [
"risk_factors = dict()\n",
"for ticker in tickers:\n",
"    # Open the filing in a context manager so the file handle is closed after the upload.\n",
"    with open(f\"./xbrl-forms/{ticker}-10k.xbrl\", \"rb\") as f:\n",
"        response = requests.post(\n",
"            url,\n",
"            files={\"text_files\": f},\n",
"            data={\"section\": [\"RISK_FACTORS\"]},\n",
"        )\n",
"    response.raise_for_status()\n",
"    risk_factors[ticker] = response.json()[\"RISK_FACTORS\"]\n",
"    # Pause between requests to avoid hammering the API.\n",
"    time.sleep(1)\n",
"    print(\".\", end=\"\")"
]
},
{
"cell_type": "markdown",
"id": "17379800",
"metadata": {},
"source": [
"### Stage for LabelStudio <a id=\"stage-label-studio\"></a>\n",
"\n",
"The next step is to label our data for the sentiment analysis model. To do that, we'll use [LabelStudio](https://labelstud.io/). The `unstructured` core library lets us easily prepare data for upload to LabelStudio using the [`stage_for_label_studio`](https://unstructured-io.github.io/unstructured/functions.html#stage-for-label-studio) staging function. In this section, we'll format the data for upload to LabelStudio, and also use an off-the-shelf sentiment analysis model to pre-annotate the data. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73a85ca3",
"metadata": {},
"outputs": [],
"source": [
"from unstructured.staging.base import isd_to_elements"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2edab5e3",
"metadata": {},
"outputs": [],
"source": [
"elements = []\n",
"for sections in risk_factors.values():\n",
" elements.extend(isd_to_elements(sections))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0ac45df",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19a7d16b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "03ab9401f78e4acaa513ea269e95cfd6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/5.47k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = \"distilbert-base-uncased-finetuned-sst-2-english\"\n",
"sentiment_pipeline = pipeline(model=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01bfd34d",
"metadata": {},
"outputs": [],
"source": [
"from unstructured.staging.label_studio import (\n",
" stage_for_label_studio,\n",
" LabelStudioAnnotation,\n",
" LabelStudioResult,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "83bc7d6a",
"metadata": {},
"source": [
"In this step, we apply an off-the-shelf sentiment analysis model to pre-annotate our data. Once it's up in LabelStudio, you'll see the model outputs applied as the default labels. Feel free to update the labels as appropriate in the LabelStudio UI."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30b00cdc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"..............................................................."
]
}
],
"source": [
"annotations = []\n",
"for i, element in enumerate(elements):\n",
"    inference = sentiment_pipeline(element.text, truncation=True)\n",
"    result = [\n",
"        LabelStudioResult(\n",
"            type=\"choices\",\n",
"            value={\"choices\": [inference[0][\"label\"].title()]},\n",
"            from_name=\"sentiment\",\n",
"            to_name=\"text\",\n",
"        )\n",
"    ]\n",
"    annotations.append([LabelStudioAnnotation(result=result)])\n",
"    # Print a progress dot once every 40 elements.\n",
"    if i % 40 == 1:\n",
"        print(\".\", end=\"\")"
]
},
{
"cell_type": "markdown",
"id": "7db2e8ae",
"metadata": {},
"source": [
"The `stage_for_label_studio` function formats the outputs for upload to LabelStudio. After we save the results as a JSON file, we can create a new project in LabelStudio and upload the training examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "773e7770",
"metadata": {},
"outputs": [],
"source": [
"label_studio_data = stage_for_label_studio(\n",
" elements=elements,\n",
" annotations=annotations,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca752c3f",
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e281376",
"metadata": {},
"outputs": [],
"source": [
"with open(\"sec-sentiment-analysis.json\", \"w\") as f:\n",
" json.dump(label_studio_data, f, indent=4)"
]
},
{
"cell_type": "markdown",
"id": "87ed6c09",
"metadata": {},
"source": [
"#### NOTE: Transition to LabelStudio\n",
"\n",
"At this point you should upload your data set to LabelStudio using the instructions in the [LabelStudio docs](https://labelstud.io/guide/tasks.html#Import-data-from-the-Label-Studio-UI). For the sentiment analysis model, choose the \"Text Classification\" template for your project. The JSON from this notebook will include annotations already, but you can improve the model by doing some additional labeling yourself. In the next step, we'll export the labeled data for model training."
]
},
{
"cell_type": "markdown",
"id": "1cdd8ddc",
"metadata": {},
"source": [
"### Train a Sentiment Model <a id=\"train\"></a>\n",
"\n",
"After labeling the data, we're ready to train the sentiment analysis model using the Hugging Face `transformers` library. Check out the [Hugging Face documentation](https://huggingface.co/blog/sentiment-analysis-python) for more information on how to train models in `transformers`.\n",
"\n",
"The first step is to export the labeled data from LabelStudio. When you export the data, select the JSON-Min format and save the file as `sec-sentiment-analysis-labeled.json` next to this notebook, since that is the filename the next cell reads. Once that's done, we'll convert the exported list of records to a Hugging Face `Dataset` object so that it can be used in the model training pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3909505f",
"metadata": {},
"outputs": [],
"source": [
"from datasets import Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c30b6171",
"metadata": {},
"outputs": [],
"source": [
"with open(\"sec-sentiment-analysis-labeled.json\", \"r\") as f:\n",
" training_data = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6625e6d2",
"metadata": {},
"outputs": [],
"source": [
"datasets_data = Dataset.from_dict(\n",
" {\n",
" \"text\": [item[\"text\"] for item in training_data],\n",
" \"label\": [0 if item[\"sentiment\"] == \"Negative\" else 1 for item in training_data],\n",
" }\n",
")"
]
},
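{
"cell_type": "markdown",
"id": "7be31d20",
"metadata": {},
"source": [
"The label encoding above maps `Negative` to 0 and every other choice to 1. Before training, it can help to check how balanced the two classes are. The sketch below does this on a few made-up records that assume only the `text` and `sentiment` keys read above; pass `training_data` in place of `sample_records` to run it on the real export."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7be31d21",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Made-up records that mimic the keys read from the LabelStudio export above.\n",
"sample_records = [\n",
"    {\"text\": \"Our costs may increase.\", \"sentiment\": \"Negative\"},\n",
"    {\"text\": \"We expect continued growth.\", \"sentiment\": \"Positive\"},\n",
"    {\"text\": \"Litigation could harm our results.\", \"sentiment\": \"Negative\"},\n",
"]\n",
"\n",
"# Same encoding as the Dataset cell: Negative -> 0, anything else -> 1.\n",
"labels = [0 if item[\"sentiment\"] == \"Negative\" else 1 for item in sample_records]\n",
"\n",
"# Count how many examples fall into each class.\n",
"print(Counter(labels))"
]
},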
{
"cell_type": "markdown",
"id": "03778b59",
"metadata": {},
"source": [
"Next, we'll read in our base model and tokenizer. For this example, we'll fine-tune the `distilbert-base-uncased` model using our labels from LabelStudio. Check out [this list](https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads) of models from Hugging Face if you want to try fine-tuning a different base model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d67ffdd",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"distilbert-base-uncased\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b54f54cf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.bias', 'classifier.weight', 'pre_classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6aa16b03",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efd70c5d",
"metadata": {},
"outputs": [],
"source": [
"def preprocess_function(examples):\n",
" return tokenizer(examples[\"text\"], truncation=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de07aa8e",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3fed372b9f25407ead2687d96fa38d93",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tokenized_train = datasets_data.map(preprocess_function, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2da3b27",
"metadata": {},
"outputs": [],
"source": [
"from transformers import DataCollatorWithPadding\n",
"\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3ee17fb",
"metadata": {},
"outputs": [],
"source": [
"from transformers import Trainer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01f47f19",
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model,\n",
" train_dataset=tokenized_train,\n",
" tokenizer=tokenizer,\n",
" data_collator=data_collator,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21fdc214",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
"/Users/mrobinson/.pyenv/versions/3.8.13/envs/sentiment/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n",
"***** Running training *****\n",
" Num examples = 2499\n",
" Num Epochs = 3\n",
" Instantaneous batch size per device = 8\n",
" Total train batch size (w. parallel, distributed & accumulation) = 8\n",
" Gradient Accumulation steps = 1\n",
" Total optimization steps = 939\n",
"You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='939' max='939' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [939/939 18:43, Epoch 3/3]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>0.233400</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Saving model checkpoint to tmp_trainer/checkpoint-500\n",
"Configuration saved in tmp_trainer/checkpoint-500/config.json\n",
"Model weights saved in tmp_trainer/checkpoint-500/pytorch_model.bin\n",
"tokenizer config file saved in tmp_trainer/checkpoint-500/tokenizer_config.json\n",
"Special tokens file saved in tmp_trainer/checkpoint-500/special_tokens_map.json\n",
"\n",
"\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\n",
"\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=939, training_loss=0.1685810637550232, metrics={'train_runtime': 1125.0908, 'train_samples_per_second': 6.663, 'train_steps_per_second': 0.835, 'total_flos': 489859524447516.0, 'train_loss': 0.1685810637550232, 'epoch': 3.0})"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "3e4540a0",
"metadata": {},
"source": [
"Now that the model is trained, we'll save it locally so we can use it for inference. Hugging Face users can also upload the model to a remote model repository."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6ffd39e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Saving model checkpoint to sec-sentiment-model\n",
"Configuration saved in sec-sentiment-model/config.json\n",
"Model weights saved in sec-sentiment-model/pytorch_model.bin\n",
"tokenizer config file saved in sec-sentiment-model/tokenizer_config.json\n",
"Special tokens file saved in sec-sentiment-model/special_tokens_map.json\n"
]
}
],
"source": [
"trainer.save_model(\"sec-sentiment-model\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9840a2c6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"loading configuration file ./sec-sentiment-model/config.json\n",
"Model config DistilBertConfig {\n",
" \"_name_or_path\": \"./sec-sentiment-model\",\n",
" \"activation\": \"gelu\",\n",
" \"architectures\": [\n",
" \"DistilBertForSequenceClassification\"\n",
" ],\n",
" \"attention_dropout\": 0.1,\n",
" \"dim\": 768,\n",
" \"dropout\": 0.1,\n",
" \"hidden_dim\": 3072,\n",
" \"initializer_range\": 0.02,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"distilbert\",\n",
" \"n_heads\": 12,\n",
" \"n_layers\": 6,\n",
" \"pad_token_id\": 0,\n",
" \"problem_type\": \"single_label_classification\",\n",
" \"qa_dropout\": 0.1,\n",
" \"seq_classif_dropout\": 0.2,\n",
" \"sinusoidal_pos_embds\": false,\n",
" \"tie_weights_\": true,\n",
" \"torch_dtype\": \"float32\",\n",
" \"transformers_version\": \"4.23.1\",\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n",
"loading configuration file ./sec-sentiment-model/config.json\n",
"Model config DistilBertConfig {\n",
" \"_name_or_path\": \"./sec-sentiment-model\",\n",
" \"activation\": \"gelu\",\n",
" \"architectures\": [\n",
" \"DistilBertForSequenceClassification\"\n",
" ],\n",
" \"attention_dropout\": 0.1,\n",
" \"dim\": 768,\n",
" \"dropout\": 0.1,\n",
" \"hidden_dim\": 3072,\n",
" \"initializer_range\": 0.02,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"distilbert\",\n",
" \"n_heads\": 12,\n",
" \"n_layers\": 6,\n",
" \"pad_token_id\": 0,\n",
" \"problem_type\": \"single_label_classification\",\n",
" \"qa_dropout\": 0.1,\n",
" \"seq_classif_dropout\": 0.2,\n",
" \"sinusoidal_pos_embds\": false,\n",
" \"tie_weights_\": true,\n",
" \"torch_dtype\": \"float32\",\n",
" \"transformers_version\": \"4.23.1\",\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n",
"loading weights file ./sec-sentiment-model/pytorch_model.bin\n",
"All model checkpoint weights were used when initializing DistilBertForSequenceClassification.\n",
"\n",
"All the weights of DistilBertForSequenceClassification were initialized from the model checkpoint at ./sec-sentiment-model.\n",
"If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.\n",
"loading file vocab.txt\n",
"loading file tokenizer.json\n",
"loading file added_tokens.json\n",
"loading file special_tokens_map.json\n",
"loading file tokenizer_config.json\n"
]
}
],
"source": [
"sec_sentiment_model = pipeline(\n",
" task=\"sentiment-analysis\",\n",
" model=\"./sec-sentiment-model\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "281fd669",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Our business, operations, and financial position are subject to various risks. Some of these risks are described below, and the reader should take such risks into account in evaluating Encompass Health or any investment decision involving Encompass Health. This section does not describe all risks that may be applicable to us, our industry, or our business, and it is intended only as a summary of material risk factors. More detailed information concerning other risks and uncertainties as well as those described below is contained in other sections of this annual report. Still other risks and uncertainties we have not or cannot foresee as material to us may also adversely affect us in the future. If any of the risks below or other risks or uncertainties discussed elsewhere in this annual report are actually realized, our business and financial condition, results of operations, and cash flows could be adversely affected. In the event the impact is materially adverse, the trading price of our common stock could decline.'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"elements[0].text"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0797f19b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'label': 'LABEL_0', 'score': 0.9983240962028503}]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sec_sentiment_model(elements[0].text)"
]
},
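{
"cell_type": "markdown",
"id": "c85e6a30",
"metadata": {},
"source": [
"The pipeline reports `LABEL_0` because the base model was loaded with `num_labels=2` and no human-readable label names. Given the encoding used to build the training set (0 for `Negative`, 1 otherwise), a small lookup can translate the output. This is a sketch: `ID_TO_SENTIMENT` and `readable` are illustrative names, and `Positive` is assumed to be the only non-negative choice in the labeling project."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c85e6a31",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative mapping; assumes the two classes are Negative (0) and Positive (1).\n",
"ID_TO_SENTIMENT = {\"LABEL_0\": \"Negative\", \"LABEL_1\": \"Positive\"}\n",
"\n",
"\n",
"def readable(predictions):\n",
"    \"\"\"Replace generic pipeline labels with sentiment names, keeping the scores.\"\"\"\n",
"    return [\n",
"        {\"label\": ID_TO_SENTIMENT.get(p[\"label\"], p[\"label\"]), \"score\": p[\"score\"]}\n",
"        for p in predictions\n",
"    ]\n",
"\n",
"\n",
"# Example input copied from the pipeline output shown above.\n",
"print(readable([{\"label\": \"LABEL_0\", \"score\": 0.9983240962028503}]))"
]
},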
{
"cell_type": "markdown",
"id": "b66e661e",
"metadata": {},
"source": [
"### Stage for Transformers <a id=\"chunk\"></a>\n",
"\n",
"Finally, we're ready to use our trained sentiment analysis model. To help, we'll apply our [`stage_for_transformers`](https://unstructured-io.github.io/unstructured/functions.html#stage-for-transformers) function, which chunks output based on the size of the attention window. In this case, we'll take the first ten paragraphs we received back from the SEC API and chunk them into two text snippets that fit into the attention window for the sentiment analysis model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ea22896",
"metadata": {},
"outputs": [],
"source": [
"from unstructured.staging.huggingface import stage_for_transformers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e23dce2",
"metadata": {},
"outputs": [],
"source": [
"chunked_text = stage_for_transformers(elements[:10], tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dae34f9c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Disabling tokenizer parallelism, we're using DataLoader multithreading already\n"
]
}
],
"source": [
"results = sec_sentiment_model(chunked_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a211cf3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'label': 'LABEL_0', 'score': 0.9982876181602478},\n",
" {'label': 'LABEL_0', 'score': 0.9985221028327942}]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}