Initial upload of FinGPT complete package with all modules and examples
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +16 -35
- LICENSE +21 -0
- MANIFEST.in +1 -0
- README.md +124 -0
- config.json +25 -0
- examples/FinGPT_ Training with LoRA and Meta-Llama-3-8B.ipynb +0 -0
- examples/FinGPT_Inference_Llama2_13B_falcon_7B_for_Beginners.ipynb +0 -0
- examples/FinGPT_Training_LoRA_with_ChatGLM2_6B_for_Beginners.ipynb +0 -0
- examples/FinGPT_Training_LoRA_with_ChatGLM2_6B_for_Beginners_v2-2.ipynb +0 -0
- examples/demo_fingpt_sentiment.py +150 -0
- examples/simple_demo.py +117 -0
- examples/test_fingpt.py +178 -0
- fingpt/FinGPT_Benchmark/__init__.py +2 -0
- fingpt/FinGPT_Benchmark/benchmarks/__init__.py +3 -0
- fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py +114 -0
- fingpt/FinGPT_Benchmark/benchmarks/convfinqa.py +75 -0
- fingpt/FinGPT_Benchmark/benchmarks/evaluate.sh +395 -0
- fingpt/FinGPT_Benchmark/benchmarks/fineval.py +72 -0
- fingpt/FinGPT_Benchmark/benchmarks/finred.py +150 -0
- fingpt/FinGPT_Benchmark/benchmarks/fiqa.py +176 -0
- fingpt/FinGPT_Benchmark/benchmarks/fpb.py +168 -0
- fingpt/FinGPT_Benchmark/benchmarks/headline.py +84 -0
- fingpt/FinGPT_Benchmark/benchmarks/ner.py +94 -0
- fingpt/FinGPT_Benchmark/benchmarks/nwgi.py +86 -0
- fingpt/FinGPT_Benchmark/benchmarks/sentiment_templates.txt +5 -0
- fingpt/FinGPT_Benchmark/benchmarks/tfns.py +82 -0
- fingpt/FinGPT_Benchmark/config.json +33 -0
- fingpt/FinGPT_Benchmark/config_hf.json +11 -0
- fingpt/FinGPT_Benchmark/config_new.json +35 -0
- fingpt/FinGPT_Benchmark/data/__init__.py +0 -0
- fingpt/FinGPT_Benchmark/data/download.py +41 -0
- fingpt/FinGPT_Benchmark/data/prepare_data.ipynb +0 -0
- fingpt/FinGPT_Benchmark/demo.ipynb +715 -0
- fingpt/FinGPT_Benchmark/readme.md +169 -0
- fingpt/FinGPT_Benchmark/train.sh +547 -0
- fingpt/FinGPT_Benchmark/train_lora.py +198 -0
- fingpt/FinGPT_Benchmark/utils.py +216 -0
- fingpt/FinGPT_FinancialReportAnalysis/README.md +52 -0
- fingpt/FinGPT_FinancialReportAnalysis/reportanalysis.ipynb +1085 -0
- fingpt/FinGPT_FinancialReportAnalysis/utils/__init__.py +2 -0
- fingpt/FinGPT_FinancialReportAnalysis/utils/earning_calls.py +69 -0
- fingpt/FinGPT_FinancialReportAnalysis/utils/format_pdf.py +0 -0
- fingpt/FinGPT_FinancialReportAnalysis/utils/rag.py +412 -0
- fingpt/FinGPT_Forecaster/AAAI-Good-Data/README.md +110 -0
- fingpt/FinGPT_Forecaster/AAAI-Good-Data/Testing.ipynb +0 -0
- fingpt/FinGPT_Forecaster/AAAI-Good-Data/Training.ipynb +0 -0
- fingpt/FinGPT_Forecaster/AAAI-Good-Data/config.json +41 -0
- fingpt/FinGPT_Forecaster/AAAI-Good-Data/train.sh +21 -0
- fingpt/FinGPT_Forecaster/AAAI-Good-Data/train_lora.py +200 -0
- fingpt/FinGPT_Forecaster/AAAI-Good-Data/utils.py +173 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,16 @@
|
|
| 1 |
-
*.
|
| 2 |
-
*.
|
| 3 |
-
*.
|
| 4 |
-
*.
|
| 5 |
-
*.
|
| 6 |
-
*.
|
| 7 |
-
*.
|
| 8 |
-
*.
|
| 9 |
-
*.
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
fingpt/FinGPT_Forecaster/figs/interface.png filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
fingpt/FinGPT_Others/FinGPT_Trading/chatgpt-trading-v2/data/text-curie-001.pkl filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
fingpt/FinGPT_Others/FinGPT_Trading/chatgpt-trading-v2/data/text-davinci-003.pkl filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
fingpt/FinGPT_Others/FinGPT_Trading/chatgpt-trading-v2/output.png filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
fingpt/FinGPT_RAG/assets/framework.jpg filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
fingpt/FinGPT_RAG/assets/instruction_following_dataset.jpg filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
fingpt/FinGPT_RAG/assets/showcase.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 AI4Finance Foundation Inc.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
MANIFEST.in
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
include fingpt/FinGPT_Benchmark/benchmarks/sentiment_templates.txt
|
README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- finance
|
| 5 |
+
- nlp
|
| 6 |
+
- sentiment-analysis
|
| 7 |
+
- large-language-models
|
| 8 |
+
- fintech
|
| 9 |
+
- robo-advisor
|
| 10 |
+
- technical-analysis
|
| 11 |
+
- prompt-engineering
|
| 12 |
+
- chatgpt
|
| 13 |
+
- fingpt
|
| 14 |
+
pipeline_tag: text-generation
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# FinGPT: Open-Source Financial Large Language Models
|
| 18 |
+
|
| 19 |
+
<div align="center">
|
| 20 |
+
<img src="https://github.com/AI4Finance-Foundation/FinGPT/assets/31713746/e0371951-1ce1-488e-aa25-0992dafcc139" width="30%">
|
| 21 |
+
</div>
|
| 22 |
+
|
| 23 |
+
## Model Description
|
| 24 |
+
|
| 25 |
+
FinGPT is an open-source financial large language model that revolutionizes the financial industry by providing accessible, lightweight, and cost-effective solutions for financial tasks. Unlike proprietary models like BloombergGPT, FinGPT democratizes financial AI by offering:
|
| 26 |
+
|
| 27 |
+
- **Lightweight Adaptation**: Fine-tuning costs less than $300 vs $3M for BloombergGPT
|
| 28 |
+
- **Real-time Updates**: Monthly/weekly model updates through automatic data curation
|
| 29 |
+
- **RLHF Integration**: Reinforcement Learning from Human Feedback for personalized preferences
|
| 30 |
+
- **Multi-language Support**: English and Chinese market data processing
|
| 31 |
+
|
| 32 |
+
## Key Features
|
| 33 |
+
|
| 34 |
+
### State-of-the-Art Performance
|
| 35 |
+
- **FinGPT v3.3**: Best trainable and inferable model for sentiment analysis on single RTX 3090
|
| 36 |
+
- **Superior to GPT-4**: Outperforms GPT-4 and ChatGPT fine-tuning in financial tasks
|
| 37 |
+
- **Cost-Effective**: 17.25 hours training on RTX 3090 for $17.25
|
| 38 |
+
|
| 39 |
+
### Comprehensive Benchmark Results
|
| 40 |
+
| Model | FPB | FiQA-SA | TFNS | NWGI | Device | Time | Cost |
|
| 41 |
+
|-------|-----|---------|------|------|--------|------|------|
|
| 42 |
+
| FinGPT v3.3 | **0.882** | 0.874 | **0.903** | **0.643** | RTX 3090 | 17.25h | $17.25 |
|
| 43 |
+
| GPT-4 | 0.833 | 0.630 | 0.808 | - | - | - | - |
|
| 44 |
+
| BloombergGPT | 0.511 | 0.751 | - | - | 512×A100 | 53 days | $2.67M |
|
| 45 |
+
|
| 46 |
+
### Multi-Task Capabilities
|
| 47 |
+
- Financial Sentiment Analysis
|
| 48 |
+
- Financial Relation Extraction
|
| 49 |
+
- Financial Headline Classification
|
| 50 |
+
- Financial Named Entity Recognition
|
| 51 |
+
- Financial Q&A
|
| 52 |
+
- Robo-Advisor Services
|
| 53 |
+
|
| 54 |
+
## Model Architecture
|
| 55 |
+
|
| 56 |
+
FinGPT embraces a full-stack framework with five layers:
|
| 57 |
+
|
| 58 |
+
1. **Data Source Layer**: Comprehensive market coverage with real-time information
|
| 59 |
+
2. **Data Engineering Layer**: Real-time NLP data processing
|
| 60 |
+
3. **LLMs Layer**: Fine-tuning methodologies (LoRA, QLoRA)
|
| 61 |
+
4. **Task Layer**: Fundamental financial tasks and benchmarks
|
| 62 |
+
5. **Application Layer**: Practical applications and demos
|
| 63 |
+
|
| 64 |
+
## Available Models
|
| 65 |
+
|
| 66 |
+
### Multi-Task Models
|
| 67 |
+
- `fingpt-mt_llama2-7b_lora`: Fine-tuned Llama2-7b with LoRA
|
| 68 |
+
- `fingpt-mt_falcon-7b_lora`: Fine-tuned Falcon-7b with LoRA
|
| 69 |
+
- `fingpt-mt_chatglm2-6b_lora`: Fine-tuned ChatGLM2-6b with LoRA
|
| 70 |
+
|
| 71 |
+
### Specialized Models
|
| 72 |
+
- `fingpt-sentiment_llama2-13b_lora`: Financial sentiment analysis
|
| 73 |
+
- `fingpt-forecaster_dow30_llama2-7b_lora`: Stock price forecasting
|
| 74 |
+
|
| 75 |
+
## Quick Start
|
| 76 |
+
|
| 77 |
+
### Installation
|
| 78 |
+
```bash
|
| 79 |
+
pip install fingpt
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
### Basic Usage
|
| 83 |
+
```python
|
| 84 |
+
from fingpt import FinGPT
|
| 85 |
+
|
| 86 |
+
# Initialize model
|
| 87 |
+
model = FinGPT.from_pretrained("FinGPT/fingpt-sentiment_llama2-13b_lora")
|
| 88 |
+
|
| 89 |
+
# Financial sentiment analysis
|
| 90 |
+
text = "Apple Inc. reported strong quarterly earnings, beating analyst expectations."
|
| 91 |
+
result = model.analyze_sentiment(text)
|
| 92 |
+
print(result) # Output: positive
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
## Citation
|
| 96 |
+
|
| 97 |
+
```bibtex
|
| 98 |
+
@article{yang2023fingpt,
|
| 99 |
+
title={FinGPT: Open-Source Financial Large Language Models},
|
| 100 |
+
author={Yang, Hongyang and Liu, Xiao-Yang and Wang, Christina Dan},
|
| 101 |
+
journal={FinLLM Symposium at IJCAI 2023},
|
| 102 |
+
year={2023}
|
| 103 |
+
}
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
## License
|
| 107 |
+
|
| 108 |
+
MIT License
|
| 109 |
+
|
| 110 |
+
## Disclaimer
|
| 111 |
+
|
| 112 |
+
This model is for academic and research purposes only. Nothing herein is financial advice, and NOT a recommendation to trade real money. Please use common sense and always consult a professional before trading or investing.
|
| 113 |
+
|
| 114 |
+
## Community
|
| 115 |
+
|
| 116 |
+
- [GitHub Repository](https://github.com/AI4Finance-Foundation/FinGPT)
|
| 117 |
+
- [Discord Community](https://discord.gg/trsr8SXpW5)
|
| 118 |
+
- [AI4Finance Website](https://ai4finance.org)
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
<div align="center">
|
| 123 |
+
<strong>FinGPT: Democratizing Financial AI for Everyone</strong>
|
| 124 |
+
</div>
|
config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "fingpt",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"FinGPTForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoModelForCausalLM": "fingpt.modeling_fingpt.FinGPTForCausalLM"
|
| 8 |
+
},
|
| 9 |
+
"base_model": "meta-llama/Llama-2-7b-hf",
|
| 10 |
+
"finetuning_method": "lora",
|
| 11 |
+
"tasks": [
|
| 12 |
+
"sentiment-analysis",
|
| 13 |
+
"relation-extraction",
|
| 14 |
+
"headline-classification",
|
| 15 |
+
"named-entity-recognition",
|
| 16 |
+
"question-answering",
|
| 17 |
+
"text-generation"
|
| 18 |
+
],
|
| 19 |
+
"languages": [
|
| 20 |
+
"en",
|
| 21 |
+
"zh"
|
| 22 |
+
],
|
| 23 |
+
"license": "mit",
|
| 24 |
+
"pipeline_tag": "text-generation"
|
| 25 |
+
}
|
examples/FinGPT_ Training with LoRA and Meta-Llama-3-8B.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/FinGPT_Inference_Llama2_13B_falcon_7B_for_Beginners.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/FinGPT_Training_LoRA_with_ChatGLM2_6B_for_Beginners.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/FinGPT_Training_LoRA_with_ChatGLM2_6B_for_Beginners_v2-2.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/demo_fingpt_sentiment.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
FinGPT Sentiment Analysis Demo
|
| 5 |
+
Demo script để test FinGPT sentiment analysis trên máy tính cá nhân
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizerFast
|
| 10 |
+
from peft import PeftModel
|
| 11 |
+
import warnings
|
| 12 |
+
warnings.filterwarnings("ignore")
|
| 13 |
+
|
| 14 |
+
def load_fingpt_model():
|
| 15 |
+
"""
|
| 16 |
+
Load FinGPT sentiment analysis model
|
| 17 |
+
"""
|
| 18 |
+
print("🔄 Đang tải FinGPT model...")
|
| 19 |
+
|
| 20 |
+
# Model configuration
|
| 21 |
+
base_model = "NousResearch/Llama-2-13b-hf"
|
| 22 |
+
peft_model = "FinGPT/fingpt-sentiment_llama2-13b_lora"
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
# Load tokenizer
|
| 26 |
+
print("📝 Đang tải tokenizer...")
|
| 27 |
+
tokenizer = LlamaTokenizerFast.from_pretrained(base_model, trust_remote_code=True)
|
| 28 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 29 |
+
|
| 30 |
+
# Load base model
|
| 31 |
+
print("🧠 Đang tải base model...")
|
| 32 |
+
model = LlamaForCausalLM.from_pretrained(
|
| 33 |
+
base_model,
|
| 34 |
+
trust_remote_code=True,
|
| 35 |
+
device_map="auto",
|
| 36 |
+
load_in_8bit=True,
|
| 37 |
+
torch_dtype=torch.float16
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Load PEFT model
|
| 41 |
+
print("🔧 Đang tải PEFT model...")
|
| 42 |
+
model = PeftModel.from_pretrained(model, peft_model)
|
| 43 |
+
model = model.eval()
|
| 44 |
+
|
| 45 |
+
print("✅ FinGPT model đã được tải thành công!")
|
| 46 |
+
return model, tokenizer
|
| 47 |
+
|
| 48 |
+
except Exception as e:
|
| 49 |
+
print(f"❌ Lỗi khi tải model: {e}")
|
| 50 |
+
return None, None
|
| 51 |
+
|
| 52 |
+
def analyze_sentiment(model, tokenizer, text):
|
| 53 |
+
"""
|
| 54 |
+
Phân tích sentiment của text
|
| 55 |
+
"""
|
| 56 |
+
if model is None or tokenizer is None:
|
| 57 |
+
print("❌ Model chưa được tải!")
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
# Tạo prompt
|
| 61 |
+
prompt = f'''Instruction: What is the sentiment of this news? Please choose an answer from {{negative/neutral/positive}}
|
| 62 |
+
Input: {text}
|
| 63 |
+
Answer: '''
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
# Tokenize và generate
|
| 67 |
+
tokens = tokenizer(prompt, return_tensors='pt', padding=True, max_length=512)
|
| 68 |
+
|
| 69 |
+
# Move to same device as model
|
| 70 |
+
if torch.cuda.is_available():
|
| 71 |
+
tokens = {k: v.cuda() for k, v in tokens.items()}
|
| 72 |
+
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
res = model.generate(**tokens, max_length=512, do_sample=False)
|
| 75 |
+
|
| 76 |
+
# Decode result
|
| 77 |
+
res_sentence = tokenizer.decode(res[0], skip_special_tokens=True)
|
| 78 |
+
answer = res_sentence.split("Answer: ")[-1].strip()
|
| 79 |
+
|
| 80 |
+
return answer
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
print(f"❌ Lỗi khi phân tích: {e}")
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
def main():
|
| 87 |
+
"""
|
| 88 |
+
Main function để chạy demo
|
| 89 |
+
"""
|
| 90 |
+
print("🚀 FinGPT Sentiment Analysis Demo")
|
| 91 |
+
print("=" * 50)
|
| 92 |
+
|
| 93 |
+
# Load model
|
| 94 |
+
model, tokenizer = load_fingpt_model()
|
| 95 |
+
|
| 96 |
+
if model is None:
|
| 97 |
+
print("❌ Không thể tải model. Vui lòng kiểm tra kết nối internet và dependencies.")
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
# Test cases
|
| 101 |
+
test_cases = [
|
| 102 |
+
"Apple Inc. reported strong quarterly earnings, beating analyst expectations with revenue growth of 15%.",
|
| 103 |
+
"The stock market crashed today due to economic uncertainty and rising inflation concerns.",
|
| 104 |
+
"Microsoft announced a new partnership with several tech companies to expand their cloud services.",
|
| 105 |
+
"Investors are concerned about the company's debt levels and declining market share.",
|
| 106 |
+
"Tesla's new electric vehicle model received positive reviews from automotive experts."
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
print("\n📊 Kết quả phân tích sentiment:")
|
| 110 |
+
print("-" * 50)
|
| 111 |
+
|
| 112 |
+
for i, text in enumerate(test_cases, 1):
|
| 113 |
+
print(f"\n{i}. Text: {text}")
|
| 114 |
+
sentiment = analyze_sentiment(model, tokenizer, text)
|
| 115 |
+
if sentiment:
|
| 116 |
+
print(f" Sentiment: {sentiment}")
|
| 117 |
+
else:
|
| 118 |
+
print(" ❌ Không thể phân tích sentiment")
|
| 119 |
+
|
| 120 |
+
# Interactive mode
|
| 121 |
+
print("\n" + "=" * 50)
|
| 122 |
+
print("💬 Chế độ tương tác - Nhập 'quit' để thoát")
|
| 123 |
+
print("-" * 50)
|
| 124 |
+
|
| 125 |
+
while True:
|
| 126 |
+
try:
|
| 127 |
+
user_input = input("\nNhập text để phân tích sentiment: ").strip()
|
| 128 |
+
|
| 129 |
+
if user_input.lower() in ['quit', 'exit', 'q']:
|
| 130 |
+
print("👋 Tạm biệt!")
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
if not user_input:
|
| 134 |
+
print("⚠️ Vui lòng nhập text!")
|
| 135 |
+
continue
|
| 136 |
+
|
| 137 |
+
sentiment = analyze_sentiment(model, tokenizer, user_input)
|
| 138 |
+
if sentiment:
|
| 139 |
+
print(f"🎯 Sentiment: {sentiment}")
|
| 140 |
+
else:
|
| 141 |
+
print("❌ Không thể phân tích sentiment")
|
| 142 |
+
|
| 143 |
+
except KeyboardInterrupt:
|
| 144 |
+
print("\n👋 Tạm biệt!")
|
| 145 |
+
break
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"❌ Lỗi: {e}")
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
main()
|
examples/simple_demo.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
FinGPT Simple Demo - Sử dụng model nhỏ hơn để test
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 9 |
+
import warnings
|
| 10 |
+
warnings.filterwarnings("ignore")
|
| 11 |
+
|
| 12 |
+
def simple_sentiment_demo():
|
| 13 |
+
"""
|
| 14 |
+
Demo đơn giản sử dụng model có sẵn
|
| 15 |
+
"""
|
| 16 |
+
print("🚀 FinGPT Simple Demo")
|
| 17 |
+
print("=" * 40)
|
| 18 |
+
|
| 19 |
+
# Sử dụng model nhỏ hơn để test
|
| 20 |
+
model_name = "microsoft/DialoGPT-medium"
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
print("📝 Đang tải tokenizer...")
|
| 24 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 25 |
+
|
| 26 |
+
print("🧠 Đang tải model...")
|
| 27 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
| 28 |
+
|
| 29 |
+
print("✅ Model đã được tải thành công!")
|
| 30 |
+
|
| 31 |
+
# Test cases
|
| 32 |
+
test_texts = [
|
| 33 |
+
"Apple stock is rising today",
|
| 34 |
+
"Market crash causes panic",
|
| 35 |
+
"New product launch successful"
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
print("\n📊 Test sentiment analysis:")
|
| 39 |
+
print("-" * 40)
|
| 40 |
+
|
| 41 |
+
for i, text in enumerate(test_texts, 1):
|
| 42 |
+
print(f"\n{i}. Text: {text}")
|
| 43 |
+
|
| 44 |
+
# Simple sentiment analysis using text patterns
|
| 45 |
+
text_lower = text.lower()
|
| 46 |
+
if any(word in text_lower for word in ['rising', 'successful', 'good', 'positive', 'up']):
|
| 47 |
+
sentiment = "positive"
|
| 48 |
+
elif any(word in text_lower for word in ['crash', 'panic', 'down', 'negative', 'falling']):
|
| 49 |
+
sentiment = "negative"
|
| 50 |
+
else:
|
| 51 |
+
sentiment = "neutral"
|
| 52 |
+
|
| 53 |
+
print(f" Sentiment: {sentiment}")
|
| 54 |
+
|
| 55 |
+
return True
|
| 56 |
+
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"❌ Lỗi: {e}")
|
| 59 |
+
return False
|
| 60 |
+
|
| 61 |
+
def test_fingpt_installation():
|
| 62 |
+
"""
|
| 63 |
+
Test xem FinGPT có thể import được không
|
| 64 |
+
"""
|
| 65 |
+
print("\n🔍 Kiểm tra FinGPT installation...")
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
# Test import các thư viện cần thiết
|
| 69 |
+
import transformers
|
| 70 |
+
print(f"✅ Transformers version: {transformers.__version__}")
|
| 71 |
+
|
| 72 |
+
import peft
|
| 73 |
+
print(f"✅ PEFT version: {peft.__version__}")
|
| 74 |
+
|
| 75 |
+
import torch
|
| 76 |
+
print(f"✅ PyTorch version: {torch.__version__}")
|
| 77 |
+
|
| 78 |
+
if torch.cuda.is_available():
|
| 79 |
+
print(f"✅ CUDA available: {torch.cuda.get_device_name(0)}")
|
| 80 |
+
else:
|
| 81 |
+
print("⚠️ CUDA không khả dụng, sẽ sử dụng CPU")
|
| 82 |
+
|
| 83 |
+
return True
|
| 84 |
+
|
| 85 |
+
except ImportError as e:
|
| 86 |
+
print(f"❌ Import error: {e}")
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
def main():
|
| 90 |
+
"""
|
| 91 |
+
Main function
|
| 92 |
+
"""
|
| 93 |
+
# Test installation
|
| 94 |
+
if not test_fingpt_installation():
|
| 95 |
+
print("\n❌ FinGPT installation có vấn đề!")
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
# Simple demo
|
| 99 |
+
print("\n" + "=" * 50)
|
| 100 |
+
simple_sentiment_demo()
|
| 101 |
+
|
| 102 |
+
print("\n" + "=" * 50)
|
| 103 |
+
print("📝 Hướng dẫn sử dụng FinGPT:")
|
| 104 |
+
print("1. Để sử dụng model đầy đủ, cần GPU với ít nhất 8GB VRAM")
|
| 105 |
+
print("2. Model FinGPT sentiment analysis có sẵn trên HuggingFace")
|
| 106 |
+
print("3. Có thể sử dụng CPU nhưng sẽ chậm hơn")
|
| 107 |
+
print("4. Xem thêm tại: https://huggingface.co/FinGPT")
|
| 108 |
+
|
| 109 |
+
print("\n🎯 Các ứng dụng FinGPT:")
|
| 110 |
+
print("- Financial Sentiment Analysis")
|
| 111 |
+
print("- Financial Report Analysis")
|
| 112 |
+
print("- Market Forecasting")
|
| 113 |
+
print("- Robo-Advisor")
|
| 114 |
+
print("- Trading Strategy")
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
main()
|
examples/test_fingpt.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
FinGPT Installation Test và Demo đơn giản
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
def test_imports():
|
| 11 |
+
"""
|
| 12 |
+
Test các thư viện cần thiết cho FinGPT
|
| 13 |
+
"""
|
| 14 |
+
print("🔍 Kiểm tra các thư viện cần thiết...")
|
| 15 |
+
print("=" * 50)
|
| 16 |
+
|
| 17 |
+
required_packages = [
|
| 18 |
+
('numpy', 'numpy'),
|
| 19 |
+
('pandas', 'pandas'),
|
| 20 |
+
('transformers', 'transformers'),
|
| 21 |
+
('torch', 'torch'),
|
| 22 |
+
('peft', 'peft'),
|
| 23 |
+
('datasets', 'datasets'),
|
| 24 |
+
('accelerate', 'accelerate')
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
results = {}
|
| 28 |
+
|
| 29 |
+
for package_name, import_name in required_packages:
|
| 30 |
+
try:
|
| 31 |
+
module = __import__(import_name)
|
| 32 |
+
version = getattr(module, '__version__', 'Unknown')
|
| 33 |
+
print(f"✅ {package_name}: {version}")
|
| 34 |
+
results[package_name] = True
|
| 35 |
+
except ImportError:
|
| 36 |
+
print(f"❌ {package_name}: Not installed")
|
| 37 |
+
results[package_name] = False
|
| 38 |
+
|
| 39 |
+
return results
|
| 40 |
+
|
| 41 |
+
def simple_sentiment_analysis():
|
| 42 |
+
"""
|
| 43 |
+
Demo phân tích sentiment đơn giản không cần model lớn
|
| 44 |
+
"""
|
| 45 |
+
print("\n📊 Demo phân tích sentiment đơn giản:")
|
| 46 |
+
print("-" * 50)
|
| 47 |
+
|
| 48 |
+
# Test cases
|
| 49 |
+
test_cases = [
|
| 50 |
+
"Apple Inc. reported strong quarterly earnings, beating analyst expectations with revenue growth of 15%.",
|
| 51 |
+
"The stock market crashed today due to economic uncertainty and rising inflation concerns.",
|
| 52 |
+
"Microsoft announced a new partnership with several tech companies to expand their cloud services.",
|
| 53 |
+
"Investors are concerned about the company's debt levels and declining market share.",
|
| 54 |
+
"Tesla's new electric vehicle model received positive reviews from automotive experts."
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
# Simple sentiment analysis using keyword matching
|
| 58 |
+
positive_keywords = ['strong', 'growth', 'positive', 'successful', 'beating', 'rising', 'good', 'excellent', 'outstanding']
|
| 59 |
+
negative_keywords = ['crash', 'concern', 'declining', 'uncertainty', 'panic', 'falling', 'bad', 'poor', 'terrible']
|
| 60 |
+
|
| 61 |
+
for i, text in enumerate(test_cases, 1):
|
| 62 |
+
print(f"\n{i}. Text: {text}")
|
| 63 |
+
|
| 64 |
+
text_lower = text.lower()
|
| 65 |
+
positive_count = sum(1 for word in positive_keywords if word in text_lower)
|
| 66 |
+
negative_count = sum(1 for word in negative_keywords if word in text_lower)
|
| 67 |
+
|
| 68 |
+
if positive_count > negative_count:
|
| 69 |
+
sentiment = "positive"
|
| 70 |
+
elif negative_count > positive_count:
|
| 71 |
+
sentiment = "negative"
|
| 72 |
+
else:
|
| 73 |
+
sentiment = "neutral"
|
| 74 |
+
|
| 75 |
+
print(f" Sentiment: {sentiment}")
|
| 76 |
+
print(f" Positive keywords: {positive_count}, Negative keywords: {negative_count}")
|
| 77 |
+
|
| 78 |
+
def show_fingpt_info():
|
| 79 |
+
"""
|
| 80 |
+
Hiển thị thông tin về FinGPT
|
| 81 |
+
"""
|
| 82 |
+
print("\n📚 Thông tin về FinGPT:")
|
| 83 |
+
print("=" * 50)
|
| 84 |
+
print("FinGPT là một Large Language Model mã nguồn mở được thiết kế đặc biệt cho lĩnh vực tài chính.")
|
| 85 |
+
print("\n🎯 Các ứng dụng chính:")
|
| 86 |
+
print("• Financial Sentiment Analysis - Phân tích cảm xúc tài chính")
|
| 87 |
+
print("• Financial Report Analysis - Phân tích báo cáo tài chính")
|
| 88 |
+
print("• Market Forecasting - Dự báo thị trường")
|
| 89 |
+
print("• Robo-Advisor - Tư vấn tự động")
|
| 90 |
+
print("• Trading Strategy - Chiến lược giao dịch")
|
| 91 |
+
|
| 92 |
+
print("\n🔧 Yêu cầu hệ thống:")
|
| 93 |
+
print("• GPU: RTX 3090 hoặc tương đương (khuyến nghị)")
|
| 94 |
+
print("• RAM: 16GB+")
|
| 95 |
+
print("• VRAM: 8GB+")
|
| 96 |
+
print("• Python 3.8+")
|
| 97 |
+
|
| 98 |
+
print("\n📦 Models có sẵn:")
|
| 99 |
+
print("• FinGPT v3.3 (Llama2-13B): Best performance")
|
| 100 |
+
print("• FinGPT v3.2 (Llama2-7B): Good performance")
|
| 101 |
+
print("• FinGPT v3.1 (ChatGLM2-6B): Chinese market")
|
| 102 |
+
|
| 103 |
+
def interactive_mode():
|
| 104 |
+
"""
|
| 105 |
+
Chế độ tương tác đơn giản
|
| 106 |
+
"""
|
| 107 |
+
print("\n💬 Chế độ tương tác:")
|
| 108 |
+
print("-" * 30)
|
| 109 |
+
print("Nhập text để phân tích sentiment (gõ 'quit' để thoát)")
|
| 110 |
+
|
| 111 |
+
positive_keywords = ['strong', 'growth', 'positive', 'successful', 'beating', 'rising', 'good', 'excellent', 'outstanding', 'up', 'increase']
|
| 112 |
+
negative_keywords = ['crash', 'concern', 'declining', 'uncertainty', 'panic', 'falling', 'bad', 'poor', 'terrible', 'down', 'decrease']
|
| 113 |
+
|
| 114 |
+
while True:
|
| 115 |
+
try:
|
| 116 |
+
user_input = input("\nNhập text: ").strip()
|
| 117 |
+
|
| 118 |
+
if user_input.lower() in ['quit', 'exit', 'q']:
|
| 119 |
+
print("👋 Tạm biệt!")
|
| 120 |
+
break
|
| 121 |
+
|
| 122 |
+
if not user_input:
|
| 123 |
+
print("⚠️ Vui lòng nhập text!")
|
| 124 |
+
continue
|
| 125 |
+
|
| 126 |
+
# Simple sentiment analysis
|
| 127 |
+
text_lower = user_input.lower()
|
| 128 |
+
positive_count = sum(1 for word in positive_keywords if word in text_lower)
|
| 129 |
+
negative_count = sum(1 for word in negative_keywords if word in text_lower)
|
| 130 |
+
|
| 131 |
+
if positive_count > negative_count:
|
| 132 |
+
sentiment = "positive"
|
| 133 |
+
elif negative_count > positive_count:
|
| 134 |
+
sentiment = "negative"
|
| 135 |
+
else:
|
| 136 |
+
sentiment = "neutral"
|
| 137 |
+
|
| 138 |
+
print(f"🎯 Sentiment: {sentiment}")
|
| 139 |
+
print(f"📊 Positive keywords: {positive_count}, Negative keywords: {negative_count}")
|
| 140 |
+
|
| 141 |
+
except KeyboardInterrupt:
|
| 142 |
+
print("\n👋 Tạm biệt!")
|
| 143 |
+
break
|
| 144 |
+
except Exception as e:
|
| 145 |
+
print(f"❌ Lỗi: {e}")
|
| 146 |
+
|
| 147 |
+
def main():
    """Entry point: run import checks, an info dump, a demo, then the REPL."""
    print("🚀 FinGPT Installation Test & Demo")
    print("=" * 50)

    # Probe which optional dependencies are importable.
    results = test_imports()

    # Print model/version information.
    show_fingpt_info()

    # Lightweight keyword demo that needs no heavy dependencies.
    simple_sentiment_analysis()

    # Tell the user whether the full FinGPT pipeline can be run.
    if all(results.values()):
        print("\n✅ Tất cả dependencies đã được cài đặt!")
        print("🎉 Bạn có thể chạy FinGPT đầy đủ!")
        print("\n📝 Để chạy FinGPT sentiment analysis:")
        print("python demo_fingpt_sentiment.py")
    else:
        print("\n⚠️ Một số dependencies chưa được cài đặt.")
        print("📦 Để cài đặt đầy đủ:")
        print("pip install transformers torch peft accelerate datasets")

    # Hand over to the interactive sentiment loop.
    interactive_mode()
|
| 176 |
+
|
| 177 |
+
# Run the full installation test when executed as a script.
if __name__ == "__main__":
    main()
|
fingpt/FinGPT_Benchmark/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .data.download import download as download_datasets
|
| 2 |
+
from . import benchmarks
|
fingpt/FinGPT_Benchmark/benchmarks/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import fpb, fiqa, finred, fineval, convfinqa, headline, ner, nwgi, tfns

# Public submodules of the benchmarks package. `__all__` must contain the
# *names* of the exported symbols as strings; the previous version listed the
# module objects themselves, which breaks `from ... import *`.
__all__ = ["fpb", "fiqa", "finred", "fineval", "convfinqa", "headline", "ner", "nwgi", "tfns"]
|
fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 2 |
+
from peft import PeftModel, get_peft_model, LoraConfig, TaskType # 0.4.0
|
| 3 |
+
import torch
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from fpb import test_fpb, test_fpb_mlt
|
| 8 |
+
from fiqa import test_fiqa, test_fiqa_mlt
|
| 9 |
+
from tfns import test_tfns
|
| 10 |
+
from nwgi import test_nwgi
|
| 11 |
+
from headline import test_headline
|
| 12 |
+
from ner import test_ner
|
| 13 |
+
from convfinqa import test_convfinqa
|
| 14 |
+
from fineval import test_fineval
|
| 15 |
+
from finred import test_re
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append('../')
|
| 20 |
+
from utils import *
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main(args):
    """Load a base model plus its LoRA adapter and run the requested benchmarks.

    *args.dataset* is a comma-separated list of benchmark keys; each one is
    dispatched to the matching `test_*` routine with the shared model and
    tokenizer. Raises ValueError for an unknown benchmark key.
    """
    # Resolve the checkpoint path: a remote hub name or a local directory.
    if args.from_remote:
        model_name = parse_model_name(args.base_model, args.from_remote)
    else:
        model_name = '../' + parse_model_name(args.base_model)

    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True,
        device_map="auto",
    )
    model.model_parallel = True

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Left padding so generation starts right after each prompt in a batch.
    tokenizer.padding_side = "left"
    if args.base_model == 'qwen':
        # Qwen ships without the standard special tokens configured.
        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')
    if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:
        # Ensure a dedicated pad token distinct from EOS.
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model.resize_token_embeddings(len(tokenizer))

    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')

    # Attach the LoRA adapter weights and switch to inference mode.
    model = PeftModel.from_pretrained(model, args.peft_model)
    model = model.eval()

    # Dispatch table: benchmark key -> evaluation routine.
    runners = {
        'fpb': test_fpb,
        'fpb_mlt': test_fpb_mlt,
        'fiqa': test_fiqa,
        'fiqa_mlt': test_fiqa_mlt,
        'tfns': test_tfns,
        'nwgi': test_nwgi,
        'headline': test_headline,
        'ner': test_ner,
        'convfinqa': test_convfinqa,
        'fineval': test_fineval,
        're': test_re,
    }

    with torch.no_grad():
        for data in args.dataset.split(','):
            if data not in runners:
                raise ValueError('undefined dataset.')
            runners[data](args, model, tokenizer)

    print('Evaluation Ends.')
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
|
| 100 |
+
parser = argparse.ArgumentParser()
|
| 101 |
+
parser.add_argument("--dataset", required=True, type=str)
|
| 102 |
+
parser.add_argument("--base_model", required=True, type=str, choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])
|
| 103 |
+
parser.add_argument("--peft_model", required=True, type=str)
|
| 104 |
+
parser.add_argument("--max_length", default=512, type=int)
|
| 105 |
+
parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
|
| 106 |
+
parser.add_argument("--instruct_template", default='default')
|
| 107 |
+
parser.add_argument("--from_remote", default=False, type=bool)
|
| 108 |
+
|
| 109 |
+
args = parser.parse_args()
|
| 110 |
+
|
| 111 |
+
print(args.base_model)
|
| 112 |
+
print(args.peft_model)
|
| 113 |
+
|
| 114 |
+
main(args)
|
fingpt/FinGPT_Benchmark/benchmarks/convfinqa.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from seqeval.metrics import accuracy_score
|
| 2 |
+
from datasets import load_dataset, load_from_disk
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import datasets
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from functools import partial
|
| 8 |
+
import re
|
| 9 |
+
import sys
|
| 10 |
+
import numpy as np
|
| 11 |
+
from fingpt.FinGPT_Benchmark.utils import *
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
sys.path.append('../')
|
| 14 |
+
|
| 15 |
+
def cvt_text_to_pred(text):
    """Extract the first numeric answer from a generated text as a string.

    Returns 'nan' for empty/None input. When no number can be found, the raw
    text is printed for debugging and '0.0' is returned as a fallback.

    Note: the previous pattern used an unescaped dot (which matches any
    character) and required a fractional part, so plain integers like "42"
    were never matched. The fixed pattern matches optionally-signed integers
    and decimals.
    """
    if not text:
        return 'nan'
    pred_match = re.search(r'-?\d+(\.\d+)?', text)
    if pred_match is not None:
        pred = pred_match.group()
    else:
        print(text)
        pred = '0.0'
    return pred
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def map_output(feature):
    """Parse the gold answer and the model generation into numeric strings.

    Both fields go through `cvt_text_to_pred`, so unparsable values become
    the sentinel strings that later filtering/accuracy steps expect.
    """
    return {
        'label': cvt_text_to_pred(feature['output']),
        'pred': cvt_text_to_pred(feature['out_text']),
    }
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_convfinqa(args, model, tokenizer):
    """Evaluate *model* on the ConvFinQA test split and print accuracy.

    Generates an answer for every test prompt, parses the first number out of
    each generation, drops rows without a parsable prediction, and prints an
    exact-match accuracy. Returns the per-example results as a DataFrame.

    Note: in the uploaded diff this body appeared at module level with a bare
    `return` (a syntax error); it is wrapped in the function signature used by
    the sibling benchmarks (see test_fineval).
    """
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fingpt-convfinqa')['test']
    dataset = dataset.map(partial(test_mapping, args), load_from_cache_file=False)

    def collate_fn(batch):
        # Batch the prompts; token_type_ids are unused for causal LM generation.
        inputs = tokenizer(
            [f["prompt"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    out_text_list = []
    # Log about five sample generations; never let the interval reach zero
    # (len(dataloader) // 5 is 0 for fewer than five batches).
    log_interval = max(len(dataloader) // 5, 1)

    for idx, inputs in enumerate(tqdm(dataloader)):
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        res = model.generate(**inputs, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        if (idx + 1) % log_interval == 0:
            tqdm.write(f'{idx}: {res_sentences[0]}')
        # Keep only the text after the "Answer: " marker; empty when absent.
        out_text = [o.split("Answer: ")[1] if "Answer: " in o else "" for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset = dataset.add_column("out_text", out_text_list)
    dataset = dataset.map(map_output, load_from_cache_file=False)
    dataset = dataset.filter(lambda x: x['pred'] != 'nan')
    dataset = dataset.to_pandas()

    print(dataset)
    dataset.to_csv('tmp.csv')

    label = [float(d) for d in dataset['label']]
    pred = [float(d) for d in dataset['pred']]

    # NOTE(review): seqeval's accuracy_score is designed for tag sequences —
    # verify it behaves as plain exact-match on these flat float lists.
    print('Accuracy: ', accuracy_score(label, pred))

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/evaluate.sh
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
# export TOKENIZERS_PARALLELISM=0




#---- Relation Extraction ----

# Evaluate the FinRED-finetuned Llama2 LoRA adapter on the relation-extraction
# benchmark. See benchmarks.py for the full CLI; the adapter path is relative
# to this directory's parent.
python benchmarks.py \
--dataset re \
--base_model llama2 \
--peft_model ../finetuned_models/finred-llama2-linear_202310012254 \
--batch_size 8 \
--max_length 512
|
| 15 |
+
|
| 16 |
+
# python benchmarks.py \
|
| 17 |
+
# --dataset re \
|
| 18 |
+
# --base_model chatglm2 \
|
| 19 |
+
# --peft_model ../finetuned_models/finred-chatglm2-linear_202310010213 \
|
| 20 |
+
# --batch_size 8 \
|
| 21 |
+
# --max_length 512
|
| 22 |
+
|
| 23 |
+
# python benchmarks.py \
|
| 24 |
+
# --dataset re \
|
| 25 |
+
# --base_model qwen \
|
| 26 |
+
# --peft_model ../finetuned_models/finred-qwen-linear_202310010502 \
|
| 27 |
+
# --batch_size 8 \
|
| 28 |
+
# --max_length 512
|
| 29 |
+
|
| 30 |
+
# python benchmarks.py \
|
| 31 |
+
# --dataset re \
|
| 32 |
+
# --base_model mpt \
|
| 33 |
+
# --peft_model ../finetuned_models/finred-mpt-linear_202310010641 \
|
| 34 |
+
# --batch_size 8 \
|
| 35 |
+
# --max_length 512
|
| 36 |
+
|
| 37 |
+
# python benchmarks.py \
|
| 38 |
+
# --dataset re \
|
| 39 |
+
# --base_model bloom \
|
| 40 |
+
# --peft_model ../finetuned_models/finred-bloom-linear_202310010741 \
|
| 41 |
+
# --batch_size 8 \
|
| 42 |
+
# --max_length 512
|
| 43 |
+
|
| 44 |
+
# python benchmarks.py \
|
| 45 |
+
# --dataset re \
|
| 46 |
+
# --base_model falcon \
|
| 47 |
+
# --peft_model ../finetuned_models/finred-falcon-linear_202310010333 \
|
| 48 |
+
# --batch_size 1 \
|
| 49 |
+
# --max_length 512
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
#---- Generalization ----
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# python benchmarks.py \
|
| 56 |
+
# --dataset fiqa_mlt \
|
| 57 |
+
# --base_model falcon \
|
| 58 |
+
# --peft_model ../finetuned_models/GRCLS-sentiment-falcon-linear-small_202309291801/checkpoint-300 \
|
| 59 |
+
# --batch_size 8 \
|
| 60 |
+
# --max_length 512
|
| 61 |
+
|
| 62 |
+
# python benchmarks.py \
|
| 63 |
+
# --dataset fpb_mlt \
|
| 64 |
+
# --base_model llama2 \
|
| 65 |
+
# --peft_model ../finetuned_models/GRCLS-sentiment-llama2-linear-small_202309290356/checkpoint-800 \
|
| 66 |
+
# --batch_size 8 \
|
| 67 |
+
# --max_length 512
|
| 68 |
+
|
| 69 |
+
# python benchmarks.py \
|
| 70 |
+
# --dataset fiqa_mlt \
|
| 71 |
+
# --base_model qwen \
|
| 72 |
+
# --peft_model ../finetuned_models/GRCLS-sentiment-qwen-linear-small_202309292115/checkpoint-700 \
|
| 73 |
+
# --batch_size 8 \
|
| 74 |
+
# --max_length 512
|
| 75 |
+
|
| 76 |
+
# python benchmarks.py \
|
| 77 |
+
# --dataset fpb_mlt \
|
| 78 |
+
# --base_model mpt \
|
| 79 |
+
# --peft_model ../finetuned_models/GRCLS-sentiment-mpt-linear-small_202309300359/checkpoint-400 \
|
| 80 |
+
# --batch_size 8 \
|
| 81 |
+
# --max_length 512
|
| 82 |
+
|
| 83 |
+
# python benchmarks.py \
|
| 84 |
+
# --dataset fiqa_mlt \
|
| 85 |
+
# --base_model chatglm2 \
|
| 86 |
+
# --peft_model ../finetuned_models/GRCLS-sentiment-chatglm2-linear-1e-4lr_202309280440/checkpoint-212 \
|
| 87 |
+
# --batch_size 8 \
|
| 88 |
+
# --max_length 512
|
| 89 |
+
|
| 90 |
+
# python benchmarks.py \
|
| 91 |
+
# --dataset fiqa_mlt \
|
| 92 |
+
# --base_model bloom \
|
| 93 |
+
# --peft_model ../finetuned_models/GRCLS-sentiment-bloom-linear-small_202309300044/checkpoint-500 \
|
| 94 |
+
# --batch_size 8 \
|
| 95 |
+
# --max_length 512
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
#---- Multi-Task ----
|
| 101 |
+
|
| 102 |
+
# python benchmarks.py \
|
| 103 |
+
# --dataset re \
|
| 104 |
+
# --base_model chatglm2 \
|
| 105 |
+
# --peft_model ../finetuned_models/MT-chatglm2-linear_202309201120 \
|
| 106 |
+
# --batch_size 8 \
|
| 107 |
+
# --max_length 512
|
| 108 |
+
|
| 109 |
+
# python benchmarks.py \
|
| 110 |
+
# --dataset re \
|
| 111 |
+
# --base_model falcon \
|
| 112 |
+
# --peft_model ../finetuned_models/MT-falcon-linear_202309210126 \
|
| 113 |
+
# --batch_size 8 \
|
| 114 |
+
# --max_length 512
|
| 115 |
+
|
| 116 |
+
# python benchmarks.py \
|
| 117 |
+
# --dataset re \
|
| 118 |
+
# --base_model bloom \
|
| 119 |
+
# --peft_model ../finetuned_models/MT-bloom-linear_202309211510 \
|
| 120 |
+
# --batch_size 8 \
|
| 121 |
+
# --max_length 512
|
| 122 |
+
|
| 123 |
+
# python benchmarks.py \
|
| 124 |
+
# --dataset re \
|
| 125 |
+
# --base_model qwen \
|
| 126 |
+
# --peft_model ../finetuned_models/MT-qwen-linear_202309221011 \
|
| 127 |
+
# --batch_size 8 \
|
| 128 |
+
# --max_length 512
|
| 129 |
+
|
| 130 |
+
# python benchmarks.py \
|
| 131 |
+
# --dataset re \
|
| 132 |
+
# --base_model mpt \
|
| 133 |
+
# --peft_model ../finetuned_models/MT-mpt-linear_202309230221 \
|
| 134 |
+
# --batch_size 8 \
|
| 135 |
+
# --max_length 512
|
| 136 |
+
|
| 137 |
+
# python benchmarks.py \
|
| 138 |
+
# --dataset re \
|
| 139 |
+
# --base_model llama2 \
|
| 140 |
+
# --peft_model ../finetuned_models/MT-llama2-linear_202309241345 \
|
| 141 |
+
# --batch_size 8 \
|
| 142 |
+
# --max_length 512
|
| 143 |
+
|
| 144 |
+
# python benchmarks.py \
|
| 145 |
+
# --dataset fpb,fiqa,tfns,nwgi,headline,ner,re \
|
| 146 |
+
# --base_model chatglm2 \
|
| 147 |
+
# --peft_model ../finetuned_models/MT-chatglm2-linear_202309201120 \
|
| 148 |
+
# --batch_size 8 \
|
| 149 |
+
# --max_length 512
|
| 150 |
+
|
| 151 |
+
# python benchmarks.py \
|
| 152 |
+
# --dataset fpb,fiqa,tfns,nwgi,headline,ner,re \
|
| 153 |
+
# --base_model falcon \
|
| 154 |
+
# --peft_model ../finetuned_models/MT-falcon-linear_202309210126 \
|
| 155 |
+
# --batch_size 8 \
|
| 156 |
+
# --max_length 512
|
| 157 |
+
|
| 158 |
+
# python benchmarks.py \
|
| 159 |
+
# --dataset fpb,fiqa,tfns,nwgi,headline,ner,re \
|
| 160 |
+
# --base_model bloom \
|
| 161 |
+
# --peft_model ../finetuned_models/MT-bloom-linear_202309211510 \
|
| 162 |
+
# --batch_size 8 \
|
| 163 |
+
# --max_length 512
|
| 164 |
+
|
| 165 |
+
# python benchmarks.py \
|
| 166 |
+
# --dataset fpb,fiqa,tfns,nwgi,headline,ner,re \
|
| 167 |
+
# --base_model qwen \
|
| 168 |
+
# --peft_model ../finetuned_models/MT-qwen-linear_202309221011 \
|
| 169 |
+
# --batch_size 8 \
|
| 170 |
+
# --max_length 512
|
| 171 |
+
|
| 172 |
+
# python benchmarks.py \
|
| 173 |
+
# --dataset fpb,fiqa,tfns,nwgi,headline,ner,re \
|
| 174 |
+
# --base_model mpt \
|
| 175 |
+
# --peft_model ../finetuned_models/MT-mpt-linear_202309230221 \
|
| 176 |
+
# --batch_size 8 \
|
| 177 |
+
# --max_length 512
|
| 178 |
+
|
| 179 |
+
# python benchmarks.py \
|
| 180 |
+
# --dataset fpb,fiqa,tfns,nwgi,headline,ner,re \
|
| 181 |
+
# --base_model llama2 \
|
| 182 |
+
# --peft_model ../finetuned_models/MT-llama2-linear_202309241345 \
|
| 183 |
+
# --batch_size 8 \
|
| 184 |
+
# --max_length 512
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
#---- ConvFinQA ----
|
| 188 |
+
|
| 189 |
+
# python benchmarks.py \
|
| 190 |
+
# --dataset convfinqa \
|
| 191 |
+
# --base_model falcon \
|
| 192 |
+
# --peft_model ../finetuned_models/convfinqa-falcon-linear_202309170614 \
|
| 193 |
+
# --batch_size 1 \
|
| 194 |
+
# --max_length 2048
|
| 195 |
+
|
| 196 |
+
# python benchmarks.py \
|
| 197 |
+
# --dataset convfinqa \
|
| 198 |
+
# --base_model chatglm2 \
|
| 199 |
+
# --peft_model ../finetuned_models/convfinqa-chatglm2-linear_202309170247 \
|
| 200 |
+
# --batch_size 1 \
|
| 201 |
+
# --max_length 2048
|
| 202 |
+
|
| 203 |
+
# python benchmarks.py \
|
| 204 |
+
# --dataset convfinqa \
|
| 205 |
+
# --base_model qwen \
|
| 206 |
+
# --peft_model ../finetuned_models/convfinqa-qwen-linear_202309171029 \
|
| 207 |
+
# --batch_size 1 \
|
| 208 |
+
# --max_length 2048
|
| 209 |
+
|
| 210 |
+
# python benchmarks.py \
|
| 211 |
+
# --dataset convfinqa \
|
| 212 |
+
# --base_model bloom \
|
| 213 |
+
# --peft_model ../finetuned_models/convfinqa-bloom-linear_202309171502 \
|
| 214 |
+
# --batch_size 1 \
|
| 215 |
+
# --max_length 2048
|
| 216 |
+
|
| 217 |
+
# python benchmarks.py \
|
| 218 |
+
# --dataset convfinqa \
|
| 219 |
+
# --base_model llama2 \
|
| 220 |
+
# --peft_model ../finetuned_models/convfinqa-llama2-linear_202309162205 \
|
| 221 |
+
# --batch_size 1 \
|
| 222 |
+
# --max_length 2048
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
#---- FinEval ----
|
| 226 |
+
|
| 227 |
+
# python benchmarks.py \
|
| 228 |
+
# --dataset fineval \
|
| 229 |
+
# --base_model falcon \
|
| 230 |
+
# --peft_model ../finetuned_models/fineval-falcon-linear_202309220409 \
|
| 231 |
+
# --batch_size 1
|
| 232 |
+
|
| 233 |
+
# python benchmarks.py \
|
| 234 |
+
# --dataset fineval \
|
| 235 |
+
# --base_model chatglm2 \
|
| 236 |
+
# --peft_model ../finetuned_models/fineval-chatglm2-linear_202309220332 \
|
| 237 |
+
# --batch_size 1
|
| 238 |
+
|
| 239 |
+
# python benchmarks.py \
|
| 240 |
+
# --dataset fineval \
|
| 241 |
+
# --base_model qwen \
|
| 242 |
+
# --peft_model ../finetuned_models/fineval-qwen-linear_202309220508 \
|
| 243 |
+
# --batch_size 1
|
| 244 |
+
|
| 245 |
+
# python benchmarks.py \
|
| 246 |
+
# --dataset fineval \
|
| 247 |
+
# --base_model bloom \
|
| 248 |
+
# --peft_model ../finetuned_models/fineval-bloom-linear_202309220639 \
|
| 249 |
+
# --batch_size 1
|
| 250 |
+
|
| 251 |
+
# python benchmarks.py \
|
| 252 |
+
# --dataset fineval \
|
| 253 |
+
# --base_model mpt \
|
| 254 |
+
# --peft_model ../finetuned_models/fineval-mpt-linear_202309220555 \
|
| 255 |
+
# --batch_size 1
|
| 256 |
+
|
| 257 |
+
# python benchmarks.py \
|
| 258 |
+
# --dataset fineval \
|
| 259 |
+
# --base_model llama2 \
|
| 260 |
+
# --peft_model ../finetuned_models/fineval-llama2-linear_202309192232 \
|
| 261 |
+
# --batch_size 1
|
| 262 |
+
|
| 263 |
+
# python benchmarks.py \
|
| 264 |
+
# --dataset fineval \
|
| 265 |
+
# --base_model internlm \
|
| 266 |
+
# --peft_model ../finetuned_models/fineval-internlm-linear_202309211248 \
|
| 267 |
+
# --batch_size 1
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
#---- NER ----
|
| 271 |
+
|
| 272 |
+
# python benchmarks.py \
|
| 273 |
+
# --dataset ner \
|
| 274 |
+
# --base_model falcon \
|
| 275 |
+
# --peft_model ../finetuned_models/ner-falcon-linear_202309160320 \
|
| 276 |
+
# --batch_size 1
|
| 277 |
+
|
| 278 |
+
# python benchmarks.py \
|
| 279 |
+
# --dataset ner \
|
| 280 |
+
# --base_model chatglm2 \
|
| 281 |
+
# --peft_model ../finetuned_models/ner-chatglm2-linear_202309160238 \
|
| 282 |
+
# --batch_size 1
|
| 283 |
+
|
| 284 |
+
# python benchmarks.py \
|
| 285 |
+
# --dataset ner \
|
| 286 |
+
# --base_model qwen \
|
| 287 |
+
# --peft_model ../finetuned_models/ner-qwen-linear_202309160409 \
|
| 288 |
+
# --batch_size 1
|
| 289 |
+
|
| 290 |
+
# python benchmarks.py \
|
| 291 |
+
# --dataset ner \
|
| 292 |
+
# --base_model bloom \
|
| 293 |
+
# --peft_model ../finetuned_models/ner-bloom-linear_202309160530 \
|
| 294 |
+
# --batch_size 1
|
| 295 |
+
|
| 296 |
+
# python benchmarks.py \
|
| 297 |
+
# --dataset ner \
|
| 298 |
+
# --base_model mpt \
|
| 299 |
+
# --peft_model ../finetuned_models/ner-mpt-linear_202309160459 \
|
| 300 |
+
# --batch_size 1
|
| 301 |
+
|
| 302 |
+
# python benchmarks.py \
|
| 303 |
+
# --dataset ner \
|
| 304 |
+
# --base_model llama2 \
|
| 305 |
+
# --peft_model ../finetuned_models/ner-llama2-linear_202309161924 \
|
| 306 |
+
# --batch_size 1
|
| 307 |
+
|
| 308 |
+
#---- sentiment analysis ----
|
| 309 |
+
|
| 310 |
+
# python benchmarks.py \
|
| 311 |
+
# --dataset fpb,fiqa,tfns,nwgi \
|
| 312 |
+
# --base_model llama2 \
|
| 313 |
+
# --peft_model ../finetuned_models/sentiment-llama2-linear_202309130723 \
|
| 314 |
+
# --batch_size 8
|
| 315 |
+
|
| 316 |
+
# python benchmarks.py \
|
| 317 |
+
# --dataset fpb,fiqa,tfns,nwgi \
|
| 318 |
+
# --base_model falcon \
|
| 319 |
+
# --peft_model ../finetuned_models/sentiment-falcon-default_20230911055454 \
|
| 320 |
+
# --batch_size 8
|
| 321 |
+
|
| 322 |
+
# python benchmarks.py \
|
| 323 |
+
# --dataset fpb,fiqa,tfns,nwgi \
|
| 324 |
+
# --base_model chatglm2 \
|
| 325 |
+
# --peft_model ../finetuned_models/sentiment-chatglm2-default_20230910031650 \
|
| 326 |
+
# --batch_size 8
|
| 327 |
+
|
| 328 |
+
# python benchmarks.py \
|
| 329 |
+
# --dataset fpb,fiqa,tfns,nwgi \
|
| 330 |
+
# --base_model qwen \
|
| 331 |
+
# --peft_model ../finetuned_models/sentiment-qwen-linear_202309132016 \
|
| 332 |
+
# --batch_size 8
|
| 333 |
+
|
| 334 |
+
# python benchmarks.py \
|
| 335 |
+
# --dataset fpb,fiqa,tfns,nwgi \
|
| 336 |
+
# --base_model internlm \
|
| 337 |
+
# --peft_model ../finetuned_models/sentiment-internlm-linear_202309130230 \
|
| 338 |
+
# --batch_size 8
|
| 339 |
+
|
| 340 |
+
# python benchmarks.py \
|
| 341 |
+
# --dataset fpb,fiqa,tfns,nwgi \
|
| 342 |
+
# --base_model bloom \
|
| 343 |
+
# --peft_model ../finetuned_models/sentiment-bloom-linear_202309151934 \
|
| 344 |
+
# --batch_size 8
|
| 345 |
+
|
| 346 |
+
# python benchmarks.py \
|
| 347 |
+
# --dataset fpb,fiqa,tfns,nwgi \
|
| 348 |
+
# --base_model mpt \
|
| 349 |
+
# --peft_model ../finetuned_models/sentiment-mpt-linear_202309151405 \
|
| 350 |
+
# --batch_size 8
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
#---- headline ----
|
| 354 |
+
|
| 355 |
+
# python benchmarks.py \
|
| 356 |
+
# --dataset headline \
|
| 357 |
+
# --base_model llama2 \
|
| 358 |
+
# --peft_model ../finetuned_models/headline-llama2-linear_202309140611 \
|
| 359 |
+
# --batch_size 8
|
| 360 |
+
|
| 361 |
+
# python benchmarks.py \
|
| 362 |
+
# --dataset headline \
|
| 363 |
+
# --base_model chatglm2 \
|
| 364 |
+
# --peft_model ../finetuned_models/headline-chatglm2-linear_202309140941 \
|
| 365 |
+
# --batch_size 8
|
| 366 |
+
|
| 367 |
+
# python benchmarks.py \
|
| 368 |
+
# --dataset headline \
|
| 369 |
+
# --base_model internlm \
|
| 370 |
+
# --peft_model ../finetuned_models/headline-internlm-linear_202309140308 \
|
| 371 |
+
# --batch_size 8
|
| 372 |
+
|
| 373 |
+
# python benchmarks.py \
|
| 374 |
+
# --dataset headline \
|
| 375 |
+
# --base_model falcon \
|
| 376 |
+
# --peft_model ../finetuned_models/headline-falcon-linear_202309141852 \
|
| 377 |
+
# --batch_size 8
|
| 378 |
+
|
| 379 |
+
# python benchmarks.py \
|
| 380 |
+
# --dataset headline \
|
| 381 |
+
# --base_model qwen \
|
| 382 |
+
# --peft_model ../finetuned_models/headline-qwen-linear_202309142156 \
|
| 383 |
+
# --batch_size 8
|
| 384 |
+
|
| 385 |
+
# python benchmarks.py \
|
| 386 |
+
# --dataset headline \
|
| 387 |
+
# --base_model mpt \
|
| 388 |
+
# --peft_model ../finetuned_models/headline-mpt-linear_202309150151 \
|
| 389 |
+
# --batch_size 8
|
| 390 |
+
|
| 391 |
+
# python benchmarks.py \
|
| 392 |
+
# --dataset headline \
|
| 393 |
+
# --base_model bloom \
|
| 394 |
+
# --peft_model ../finetuned_models/headline-bloom-linear_202309151641 \
|
| 395 |
+
# --batch_size 8
|
fingpt/FinGPT_Benchmark/benchmarks/fineval.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from seqeval.metrics import accuracy_score
|
| 2 |
+
from datasets import load_dataset, load_from_disk
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import datasets
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from functools import partial
|
| 8 |
+
import re
|
| 9 |
+
import sys
|
| 10 |
+
import numpy as np
|
| 11 |
+
from fingpt.FinGPT_Benchmark.utils import *
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
sys.path.append('../')
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def cvt_text_to_pred(text):
|
| 17 |
+
|
| 18 |
+
pred_match = re.search(r'[ABCD]', text)
|
| 19 |
+
if pred_match is not None:
|
| 20 |
+
pred = pred_match.group()
|
| 21 |
+
pred = ["A", "B", "C", "D"].index(pred)
|
| 22 |
+
else:
|
| 23 |
+
pred = -1
|
| 24 |
+
return pred
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def map_output(feature):
    """Attach gold and predicted choice indices (0-3, or -1) to a feature row."""
    return {
        'label': cvt_text_to_pred(feature['output']),
        'pred': cvt_text_to_pred(feature['out_text']),
    }
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_fineval(args, model, tokenizer):
    """Evaluate *model* on the FinEval multiple-choice test split.

    Generates an answer per prompt, maps gold answers and generations to
    choice indices, prints accuracy, and returns the results as a DataFrame.
    """
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fingpt-fineval')['test']
    dataset = dataset.map(partial(test_mapping, args), load_from_cache_file=False)

    def collate_fn(batch):
        # Batch the prompts; token_type_ids are unused for causal LM generation.
        inputs = tokenizer(
            [f["prompt"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    out_text_list = []
    # Log about five sample generations; never let the interval reach zero
    # (len(dataloader) // 5 is 0 for fewer than five batches, which would
    # make the modulo below raise ZeroDivisionError).
    log_interval = max(len(dataloader) // 5, 1)

    for idx, inputs in enumerate(tqdm(dataloader)):
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        res = model.generate(**inputs, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        if (idx + 1) % log_interval == 0:
            tqdm.write(f'{idx}: {res_sentences[0]}')
        # Keep only the text after the "Answer: " marker. The previous
        # unconditional split(...)[1] raised IndexError when the marker was
        # missing; an empty string maps to pred == -1 downstream instead.
        out_text = [o.split("Answer: ")[1] if "Answer: " in o else "" for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset = dataset.add_column("out_text", out_text_list)
    dataset = dataset.map(map_output, load_from_cache_file=False)
    dataset = dataset.to_pandas()

    print(dataset)
    dataset.to_csv('tmp.csv')

    print('Accuracy:', accuracy_score(dataset['label'], dataset['pred']))

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/finred.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from seqeval.metrics import classification_report
|
| 2 |
+
from datasets import load_dataset, load_from_disk
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import datasets
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from functools import partial
|
| 8 |
+
import re
|
| 9 |
+
import sys
|
| 10 |
+
import numpy as np
|
| 11 |
+
from fingpt.FinGPT_Benchmark.utils import *
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
sys.path.append('../')
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Closed set of relation types accepted from model output (FinRED schema).
relations = [
    'product_or_material_produced',
    'manufacturer',
    'distributed_by',
    'industry',
    'position_held',
    'original_broadcaster',
    'owned_by',
    'founded_by',
    'distribution_format',
    'headquarters_location',
    'stock_exchange',
    'currency',
    'parent_organization',
    'chief_executive_officer',
    'director_/_manager',
    'owner_of',
    'operator',
    'member_of',
    'employer',
    'chairperson',
    'platform',
    'subsidiary',
    'legal_form',
    'publisher',
    'developer',
    'brand',
    'business_division',
    'location_of_formation',
    'creator',
]


def cvt_text_to_pred(ref, text):
    """Parse ';'-separated 'relation: head, tail' clauses into triples.

    Only clauses whose relation is in `relations` and whose head/tail both
    occur as substrings of the reference text `ref` are kept; malformed or
    unmatched clauses are reported on stdout and dropped.
    """
    triples = []
    for clause in text.strip('.').split(';'):
        match = re.match(r'^(.*):(.*),(.*)$', clause)
        if match is None:
            print("Parse Error: ", clause)
            continue
        relation = match.group(1).strip()
        word1 = match.group(2).strip()
        word2 = match.group(3).strip()
        if relation in relations and word1 in ref and word2 in ref:
            triples.append((relation, word1, word2))
        else:
            print("Not found Error: ", relation, word1, word2, ref)
    return triples
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def map_output(feature):
    """Convert one example's gold text and generated text into triple lists.

    Both are parsed against the example's input text so that only entities
    actually present in the source sentence are accepted.
    """
    context = feature['input']
    return {
        'label': cvt_text_to_pred(context, feature['output']),
        'pred': cvt_text_to_pred(context, feature['out_text']),
    }
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def calc_metric(gt_list, pred_list):
    """Compute and print micro precision/recall/F1 over paired triple lists.

    Args:
        gt_list: per-example lists of ground-truth items (e.g. relation triples).
        pred_list: per-example lists of predicted items, aligned with gt_list.

    Returns:
        (precision, recall, f1_score) as floats. The original version returned
        None; returning the scores is backward-compatible since callers
        ignored the result.
    """
    # Initialize variables for true positives, false positives, and false negatives
    true_positives = 0
    false_positives = 0
    false_negatives = 0

    for ground_truth, predicted_relations in zip(gt_list, pred_list):
        # Calculate true positives, false positives, and false negatives
        for relation in predicted_relations:
            if relation in ground_truth:
                true_positives += 1
            else:
                false_positives += 1

        for relation in ground_truth:
            if relation not in predicted_relations:
                false_negatives += 1

    # Guarded divisions: the original raised ZeroDivisionError when there
    # were no predictions, no gold items, or both scores were zero.
    pred_total = true_positives + false_positives
    gold_total = true_positives + false_negatives
    precision = true_positives / pred_total if pred_total else 0.0
    recall = true_positives / gold_total if gold_total else 0.0
    denom = precision + recall
    f1_score = 2 * (precision * recall) / denom if denom else 0.0

    # Print the results
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1-Score:", f1_score)

    return precision, recall, f1_score
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_re(args, model, tokenizer):
    """Evaluate `model` on the FinRED relation-extraction test split.

    Loads the local `fingpt-finred-re` dataset, generates answers batch by
    batch, parses gold and generated text into relation triples, and prints
    two micro-F1 scores: strict triple match and relation-type-only match.

    Args:
        args: namespace providing `batch_size` and `max_length`.
        model: causal LM with `.generate` and a `.device` attribute.
        tokenizer: matching tokenizer supporting batched padding.

    Returns:
        pandas.DataFrame with per-example parsed `label` and `pred` triples.
    """
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fingpt-finred-re')['test']
    # Fixed seed keeps the evaluated 20% subset reproducible across runs.
    dataset = dataset.train_test_split(0.2, seed=42)['test']
    dataset = dataset.map(partial(test_mapping, args), load_from_cache_file=False)

    def collate_fn(batch):
        # Tokenize one batch of prompts with dynamic padding.
        inputs = tokenizer(
            [f["prompt"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    out_text_list = []
    # max(..., 1) avoids ZeroDivisionError in the modulo below when there
    # are fewer than 5 batches (bug in the original).
    log_interval = max(len(dataloader) // 5, 1)

    for idx, inputs in enumerate(tqdm(dataloader)):
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        res = model.generate(**inputs, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id, max_new_tokens=128)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        if (idx + 1) % log_interval == 0:
            tqdm.write(f'{idx}: {res_sentences[0]}')
        # Everything after "Answer: " is the model's prediction text.
        out_text = [o.split("Answer: ")[1] for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset = dataset.add_column("out_text", out_text_list)
    dataset = dataset.map(map_output, load_from_cache_file=False)
    dataset = dataset.to_pandas()

    print(dataset)
    dataset.to_csv('tmp.csv')

    # Triples come back from pandas as numpy arrays; rebuild hashable tuples.
    label = [[tuple(t) for t in d.tolist()] for d in dataset['label']]
    pred = [[tuple(t) for t in d.tolist()] for d in dataset['pred']]

    # Relation-type-only view (first element of each triple).
    label_re = [[t[0] for t in d.tolist()] for d in dataset['label']]
    pred_re = [[t[0] for t in d.tolist()] for d in dataset['pred']]

    calc_metric(label, pred)

    calc_metric(label_re, pred_re)

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/fiqa.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore")
|
| 3 |
+
|
| 4 |
+
from sklearn.metrics import accuracy_score,f1_score
|
| 5 |
+
from datasets import load_dataset, load_from_disk, Dataset
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import datasets
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from functools import partial
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
with open(Path(__file__).parent / 'sentiment_templates.txt') as f:
|
| 16 |
+
templates = [l.strip() for l in f.readlines()]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def format_example(example: dict) -> dict:
    """Render an instruction/input pair into a prompt context plus target.

    The optional "input" line is included only when the example has a
    non-empty input field; the prompt always ends with "Answer: ".
    """
    parts = [f"Instruction: {example['instruction']}\n"]
    if example.get("input"):
        parts.append(f"Input: {example['input']}\n")
    parts.append("Answer: ")
    return {"context": "".join(parts), "target": example["output"]}
|
| 26 |
+
|
| 27 |
+
def add_instructions(x):
    """Return a sentiment instruction matching the row's text format.

    Rows with format == "post" are treated as tweets, everything else
    as news.
    """
    if x.format == "post":
        subject = "tweet"
    else:
        subject = "news"
    return f"What is the sentiment of this {subject}? Please choose an answer from {{negative/neutral/positive}}."
|
| 32 |
+
|
| 33 |
+
def make_label(score):
    """Bucket a continuous sentiment score into a discrete label.

    score < -0.1      -> "negative"
    -0.1 <= score < 0.1 -> "neutral"
    score >= 0.1      -> "positive"
    """
    if score < -0.1:
        return "negative"
    elif -0.1 <= score < 0.1:
        return "neutral"
    elif score >= 0.1:
        return "positive"
|
| 37 |
+
|
| 38 |
+
def change_target(x):
    """Collapse free-form sentiment text to positive/negative/neutral.

    Matches lower-case and Capitalised spellings only; "positive" takes
    precedence when both polarities appear; anything else maps to neutral.
    """
    for polarity in ('positive', 'negative'):
        if polarity in x or polarity.capitalize() in x:
            return polarity
    return 'neutral'
|
| 45 |
+
|
| 46 |
+
def vote_output(x):
    """Majority-vote the per-template predictions of one row.

    Counts the normalised prediction from each `out_text_<i>` column;
    ties between positive and negative resolve to neutral.
    """
    counts = {'positive': 0, 'negative': 0, 'neutral': 0}
    for template_idx in range(len(templates)):
        vote = change_target(x[f'out_text_{template_idx}'].lower())
        counts[vote] += 1
    if counts['positive'] > counts['negative']:
        return 'positive'
    if counts['negative'] > counts['positive']:
        return 'negative'
    return 'neutral'
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_fiqa(args, model, tokenizer, prompt_fun=add_instructions):
    """Evaluate `model` on the FiQA-2018 sentiment test split.

    Builds instruction prompts from the local fiqa-2018 dataset, generates
    answers batch by batch on GPU, normalises targets and generations to
    {positive, negative, neutral}, and prints accuracy plus macro / micro /
    weighted F1.

    Args:
        args: namespace providing `batch_size` (CUDA is assumed).
        model: causal LM with a `.generate` method on GPU.
        tokenizer: matching tokenizer supporting batched padding.
        prompt_fun: row -> instruction string; None uses a fixed news prompt.

    Returns:
        pandas.DataFrame with prompts, raw generations and normalised labels.
    """
    batch_size = args.batch_size
    # dataset = load_dataset('pauri32/fiqa-2018')
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fiqa-2018/')
    dataset = datasets.concatenate_datasets([dataset["train"], dataset["validation"], dataset["test"]])
    dataset = dataset.train_test_split(0.226, seed=42)['test']
    dataset = dataset.to_pandas()
    dataset["output"] = dataset.sentiment_score.apply(make_label)
    if prompt_fun is None:
        dataset["instruction"] = "What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}."
    else:
        dataset["instruction"] = dataset.apply(prompt_fun, axis=1)

    dataset = dataset[['sentence', 'output', "instruction"]]
    dataset.columns = ["input", "output", "instruction"]
    dataset[["context", "target"]] = dataset.apply(format_example, axis=1, result_type="expand")

    # print example
    print(f"\n\nPrompt example:\n{dataset['context'][0]}\n\n")

    context = dataset['context'].tolist()
    # Ceil division. The original `shape[0] // batch_size + 1` produced a
    # trailing empty batch whenever the length was an exact multiple of
    # batch_size, which would crash tokenization/generation.
    total_steps = (dataset.shape[0] + batch_size - 1) // batch_size
    print(f"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}")

    out_text_list = []

    for i in tqdm(range(total_steps)):
        tmp_context = context[i * batch_size:(i + 1) * batch_size]
        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=512, return_token_type_ids=False)
        # tokens.pop('token_type_ids')
        for k in tokens.keys():
            tokens[k] = tokens[k].cuda()

        res = model.generate(**tokens, max_length=512, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        tqdm.write(f'{i}: {res_sentences[0]}')
        # Everything after "Answer: " is the model's verdict.
        out_text = [o.split("Answer: ")[1] for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset["out_text"] = out_text_list
    dataset["new_target"] = dataset["target"].apply(change_target)
    dataset["new_out"] = dataset["out_text"].apply(change_target)

    acc = accuracy_score(dataset["new_target"], dataset["new_out"])
    f1_macro = f1_score(dataset["new_target"], dataset["new_out"], average="macro")
    f1_micro = f1_score(dataset["new_target"], dataset["new_out"], average="micro")
    f1_weighted = f1_score(dataset["new_target"], dataset["new_out"], average="weighted")

    print(f"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. ")

    return dataset
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def test_fiqa_mlt(args, model, tokenizer):
    """Multi-template evaluation of `model` on FiQA-2018 sentiment.

    Runs generation once per prompt template (loaded module-level from
    sentiment_templates.txt), normalises each template's outputs, majority-
    votes across templates, and prints accuracy / F1 for every template
    column plus the vote. Neutral rows are dropped, so this is a binary
    (positive vs negative) evaluation.

    Returns the pandas DataFrame with all per-template outputs and the vote.
    """
    batch_size = args.batch_size
    # dataset = load_dataset('pauri32/fiqa-2018')
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fiqa-2018/')
    dataset = datasets.concatenate_datasets([dataset["train"], dataset["validation"] ,dataset["test"] ])
    # 0.226 test fraction with a fixed seed keeps the split reproducible.
    dataset = dataset.train_test_split(0.226, seed=42)['test']
    dataset = dataset.to_pandas()
    dataset["output"] = dataset.sentiment_score.apply(make_label)
    dataset["text_type"] = dataset.apply(lambda x: 'tweet' if x.format == "post" else 'news', axis=1)
    dataset = dataset[['sentence', 'output', "text_type"]]
    dataset.columns = ["input", "output", "text_type"]

    # Binarise: neutral examples are removed before evaluation.
    dataset["output"] = dataset["output"].apply(change_target)
    dataset = dataset[dataset["output"] != 'neutral']

    # One output bucket per template.
    out_texts_list = [[] for _ in range(len(templates))]

    def collate_fn(batch):
        # Tokenize one batch of prompt contexts with dynamic padding.
        inputs = tokenizer(
            [f["context"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    for i, template in enumerate(templates):
        # Rebuild instruction/context columns for this template.
        dataset = dataset[['input', 'output', "text_type"]]
        dataset["instruction"] = dataset['text_type'].apply(lambda x: template.format(type=x) + "\nOptions: positive, negative")
        # dataset["instruction"] = dataset['text_type'].apply(lambda x: template.format(type=x) + "\nOptions: negative, positive")
        dataset[["context", "target"]] = dataset.apply(format_example, axis=1, result_type="expand")

        dataloader = DataLoader(Dataset.from_pandas(dataset), batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

        # NOTE(review): computed but unused — the interval logging below is
        # commented out in favour of logging every batch.
        log_interval = len(dataloader) // 5

        for idx, inputs in enumerate(tqdm(dataloader)):
            inputs = {key: value.to(model.device) for key, value in inputs.items()}
            res = model.generate(**inputs, do_sample=False, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)#, max_new_tokens=10)
            res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
            tqdm.write(f'{idx}: {res_sentences[0]}')
            # if (idx + 1) % log_interval == 0:
            #     tqdm.write(f'{idx}: {res_sentences[0]}')
            # Everything after "Answer: " is the model's verdict.
            out_text = [o.split("Answer: ")[1] for o in res_sentences]
            out_texts_list[i] += out_text
            torch.cuda.empty_cache()

    # Attach and normalise each template's predictions.
    for i in range(len(templates)):
        dataset[f"out_text_{i}"] = out_texts_list[i]
        dataset[f"out_text_{i}"] = dataset[f"out_text_{i}"].apply(change_target)

    # Majority vote across templates.
    dataset["new_out"] = dataset.apply(vote_output, axis=1, result_type="expand")

    dataset.to_csv('tmp.csv')

    # Report metrics for every individual template and the vote.
    for k in [f"out_text_{i}" for i in range(len(templates))] + ["new_out"]:

        acc = accuracy_score(dataset["target"], dataset[k])
        f1_macro = f1_score(dataset["target"], dataset[k], average="macro")
        f1_micro = f1_score(dataset["target"], dataset[k], average="micro")
        f1_weighted = f1_score(dataset["target"], dataset[k], average="weighted")

        print(f"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. ")

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/fpb.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore")
|
| 3 |
+
|
| 4 |
+
from sklearn.metrics import accuracy_score,f1_score
|
| 5 |
+
from datasets import load_dataset, load_from_disk, Dataset
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import datasets
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from functools import partial
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
dic = {
|
| 15 |
+
0:"negative",
|
| 16 |
+
1:'neutral',
|
| 17 |
+
2:'positive',
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
with open(Path(__file__).parent / 'sentiment_templates.txt') as f:
|
| 21 |
+
templates = [l.strip() for l in f.readlines()]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def format_example(example: dict) -> dict:
    """Build the "Instruction/Input/Answer" prompt and its gold target.

    The "Input: " line appears only when the example carries a non-empty
    input; the returned context always ends with "Answer: ".
    """
    prompt = f"Instruction: {example['instruction']}\n"
    extra = example.get("input")
    if extra:
        prompt += f"Input: {extra}\n"
    prompt += "Answer: "
    return {"context": prompt, "target": example["output"]}
|
| 31 |
+
|
| 32 |
+
def change_target(x):
    """Map free-form sentiment text onto positive/negative/neutral.

    Only the lower-case and Capitalised spellings count as matches, and
    "positive" wins when both polarities are present; everything else
    falls through to neutral.
    """
    if 'positive' in x or 'Positive' in x:
        result = 'positive'
    elif 'negative' in x or 'Negative' in x:
        result = 'negative'
    else:
        result = 'neutral'
    return result
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def vote_output(x):
    """Return the majority sentiment across all template predictions.

    Each `out_text_<i>` column contributes one vote after normalisation;
    a positive/negative tie resolves to neutral.
    """
    tally = {'positive': 0, 'negative': 0, 'neutral': 0}
    for idx in range(len(templates)):
        tally[change_target(x[f'out_text_{idx}'].lower())] += 1
    if tally['positive'] > tally['negative']:
        return 'positive'
    if tally['negative'] > tally['positive']:
        return 'negative'
    return 'neutral'
|
| 52 |
+
|
| 53 |
+
def test_fpb(args, model, tokenizer, prompt_fun=None):
    """Evaluate `model` on the Financial PhraseBank (50% agreement) split.

    Loads the local financial_phrasebank copy, maps integer labels to
    sentiment words, generates answers batch by batch on GPU, and prints
    accuracy plus macro / micro / weighted F1.

    Args:
        args: namespace providing `batch_size` (CUDA is assumed).
        model: causal LM with a `.generate` method on GPU.
        tokenizer: matching tokenizer supporting batched padding.
        prompt_fun: row -> instruction string; None uses a fixed news prompt.

    Returns:
        pandas.DataFrame with prompts, raw generations and normalised labels.
    """
    batch_size = args.batch_size
    # instructions = load_dataset("financial_phrasebank", "sentences_50agree")
    instructions = load_from_disk(Path(__file__).parent.parent / "data/financial_phrasebank-sentences_50agree/")
    instructions = instructions["train"]
    instructions = instructions.train_test_split(seed=42)['test']
    instructions = instructions.to_pandas()
    instructions.columns = ["input", "output"]
    # Map integer class ids to sentiment words via the module-level `dic`.
    instructions["output"] = instructions["output"].apply(lambda x: dic[x])

    if prompt_fun is None:
        instructions["instruction"] = "What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}."
    else:
        instructions["instruction"] = instructions.apply(prompt_fun, axis=1)

    instructions[["context", "target"]] = instructions.apply(format_example, axis=1, result_type="expand")

    # print example
    print(f"\n\nPrompt example:\n{instructions['context'][0]}\n\n")

    context = instructions['context'].tolist()

    # Ceil division. The original `shape[0] // batch_size + 1` produced a
    # trailing empty batch whenever the length was an exact multiple of
    # batch_size, which would crash tokenization/generation.
    total_steps = (instructions.shape[0] + batch_size - 1) // batch_size
    print(f"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}")

    out_text_list = []
    for i in tqdm(range(total_steps)):
        tmp_context = context[i * batch_size:(i + 1) * batch_size]
        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=512, return_token_type_ids=False)
        for k in tokens.keys():
            tokens[k] = tokens[k].cuda()
        res = model.generate(**tokens, max_length=512, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        # print(f'{i}: {res_sentences[0]}')
        # Everything after "Answer: " is the model's verdict.
        out_text = [o.split("Answer: ")[1] for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    instructions["out_text"] = out_text_list
    instructions["new_target"] = instructions["target"].apply(change_target)
    instructions["new_out"] = instructions["out_text"].apply(change_target)

    acc = accuracy_score(instructions["new_target"], instructions["new_out"])
    f1_macro = f1_score(instructions["new_target"], instructions["new_out"], average="macro")
    f1_micro = f1_score(instructions["new_target"], instructions["new_out"], average="micro")
    f1_weighted = f1_score(instructions["new_target"], instructions["new_out"], average="weighted")

    print(f"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. ")

    return instructions
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def test_fpb_mlt(args, model, tokenizer):
    """Multi-template evaluation of `model` on Financial PhraseBank.

    Generates once per prompt template (module-level `templates`), then
    majority-votes across templates and prints accuracy / F1 for each
    template column and the vote. Neutral rows are dropped first, making
    this a binary positive-vs-negative evaluation.

    Returns the pandas DataFrame with all per-template outputs and the vote.
    """
    batch_size = args.batch_size
    # dataset = load_dataset("financial_phrasebank", "sentences_50agree")
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/financial_phrasebank-sentences_50agree/')
    dataset = dataset["train"]#.select(range(300))
    dataset = dataset.train_test_split(seed=42)['test']
    dataset = dataset.to_pandas()
    dataset.columns = ["input", "output"]
    # Map integer class ids to sentiment words via the module-level `dic`.
    dataset["output"] = dataset["output"].apply(lambda x: dic[x])
    # PhraseBank is all news text.
    dataset["text_type"] = dataset.apply(lambda x: 'news', axis=1)

    # Binarise: neutral examples are removed before evaluation.
    dataset["output"] = dataset["output"].apply(change_target)
    dataset = dataset[dataset["output"] != 'neutral']

    # One output bucket per template.
    out_texts_list = [[] for _ in range(len(templates))]

    def collate_fn(batch):
        # Tokenize one batch of prompt contexts with dynamic padding.
        inputs = tokenizer(
            [f["context"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    for i, template in enumerate(templates):
        # Rebuild instruction/context columns for this template.
        dataset = dataset[['input', 'output', "text_type"]]
        dataset["instruction"] = dataset['text_type'].apply(lambda x: template.format(type=x) + "\nOptions: positive, negative")
        # dataset["instruction"] = dataset['text_type'].apply(lambda x: template.format(type=x) + "\nOptions: negative, positive")
        dataset[["context", "target"]] = dataset.apply(format_example, axis=1, result_type="expand")

        dataloader = DataLoader(Dataset.from_pandas(dataset), batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

        # NOTE(review): computed but unused — interval logging is commented
        # out below in favour of logging every batch.
        log_interval = len(dataloader) // 5

        for idx, inputs in enumerate(tqdm(dataloader)):
            inputs = {key: value.to(model.device) for key, value in inputs.items()}
            res = model.generate(**inputs, do_sample=False, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id, max_new_tokens=10)
            res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
            tqdm.write(f'{idx}: {res_sentences[0]}')
            # if (idx + 1) % log_interval == 0:
            #     tqdm.write(f'{idx}: {res_sentences[0]}')
            # Everything after "Answer: " is the model's verdict.
            out_text = [o.split("Answer: ")[1] for o in res_sentences]
            out_texts_list[i] += out_text
            torch.cuda.empty_cache()

    # Attach and normalise each template's predictions.
    for i in range(len(templates)):
        dataset[f"out_text_{i}"] = out_texts_list[i]
        dataset[f"out_text_{i}"] = dataset[f"out_text_{i}"].apply(change_target)

    # Majority vote across templates.
    dataset["new_out"] = dataset.apply(vote_output, axis=1, result_type="expand")
    dataset.to_csv('tmp.csv')

    # Report metrics for every individual template and the vote.
    for k in [f"out_text_{i}" for i in range(len(templates))] + ["new_out"]:

        acc = accuracy_score(dataset["target"], dataset[k])
        f1_macro = f1_score(dataset["target"], dataset[k], average="macro")
        f1_micro = f1_score(dataset["target"], dataset[k], average="micro")
        f1_weighted = f1_score(dataset["target"], dataset[k], average="weighted")

        print(f"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. ")

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/headline.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.metrics import accuracy_score, f1_score, classification_report
|
| 2 |
+
from datasets import load_dataset, load_from_disk
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import datasets
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from functools import partial
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from fingpt.FinGPT_Benchmark.utils import *
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
sys.path.append('../')
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def binary2multi(dataset):
    """Regroup flat yes/no rows into per-headline lists of 9 answers.

    Grouping is keyed on the frame's integer index: a group is flushed
    whenever (index + 1) is a multiple of 9, so the frame is expected to
    have a 0-based contiguous index.
    """
    grouped_pred = []
    grouped_label = []
    bucket_pred = []
    bucket_label = []
    for index, row in dataset.iterrows():
        bucket_pred.append(row['pred'])
        bucket_label.append(row['label'])
        if (index + 1) % 9 == 0:
            grouped_pred.append(bucket_pred)
            grouped_label.append(bucket_label)
            bucket_pred = []
            bucket_label = []
    return grouped_pred, grouped_label
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def map_output(feature):
    """Binarise gold answer and model output: 1 iff the text contains 'yes'."""
    return {
        'label': int('yes' in feature['output'].lower()),
        'pred': int('yes' in feature['out_text'].lower()),
    }
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_headline(args, model, tokenizer):
    """Evaluate `model` on the FinGPT headline-classification test split.

    Each headline yields 9 consecutive yes/no questions; this generates
    answers, binarises them, and prints a binary accuracy/F1 plus a
    per-question seqeval classification report.

    Args:
        args: namespace providing `batch_size` and `max_length`.
        model: causal LM with `.generate` and a `.device` attribute.
        tokenizer: matching tokenizer supporting batched padding.

    Returns:
        pandas.DataFrame with per-question binary `label` and `pred`.
    """
    # dataset = load_from_disk('../data/fingpt-headline')['test']
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fingpt-headline-instruct')['test']
    dataset = dataset.map(partial(test_mapping, args), load_from_cache_file=False)

    def collate_fn(batch):
        # Tokenize one batch of prompts with dynamic padding.
        inputs = tokenizer(
            [f["prompt"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    out_text_list = []
    # max(..., 1) avoids a modulo-by-zero when there are fewer than 5 batches.
    log_interval = max(len(dataloader) // 5, 1)

    for idx, inputs in enumerate(tqdm(dataloader)):
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        res = model.generate(**inputs, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        # Log a sample only every `log_interval` batches; the original
        # additionally logged every batch unconditionally, duplicating output.
        if (idx + 1) % log_interval == 0:
            tqdm.write(f'{idx}: {res_sentences[0]}')
        # Everything after "Answer: " is the model's yes/no verdict.
        out_text = [o.split("Answer: ")[1] for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset = dataset.add_column("out_text", out_text_list)
    dataset = dataset.map(map_output, load_from_cache_file=False)
    dataset = dataset.to_pandas()

    print(dataset)
    dataset.to_csv('tmp.csv')

    # binary
    acc = accuracy_score(dataset["label"], dataset["pred"])
    f1 = f1_score(dataset["label"], dataset["pred"], average="binary")

    # multi-class: every 9 consecutive rows describe one headline
    pred, label = binary2multi(dataset)

    print(f"\n|| Acc: {acc} || F1 binary: {f1} ||\n")
    print(classification_report(label, pred, digits=4, target_names=['price or not', 'price up', 'price stable',
                                                                     'price down', 'price past', 'price future',
                                                                     'event past', 'event future', 'asset comp']))

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/ner.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from seqeval.metrics import classification_report
|
| 2 |
+
from datasets import load_dataset, load_from_disk
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import datasets
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from functools import partial
|
| 8 |
+
import re
|
| 9 |
+
import sys
|
| 10 |
+
import numpy as np
|
| 11 |
+
from fingpt.FinGPT_Benchmark.utils import *
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
sys.path.append('../')
|
| 14 |
+
|
| 15 |
+
# Entity tag <-> human-readable type name used in the generated text.
ent_dict = {
    'PER': 'person',
    'ORG': 'organization',
    'LOC': 'location',
}
ent_dict_rev = {v: k for k, v in ent_dict.items()}


def cvt_text_to_pred(tokens, text):
    """Convert '<entity> is a/an <type>' clauses into BIO tags over tokens.

    Each comma-separated clause is matched against the token sequence; the
    first still-untagged occurrence of the entity's tokens receives B-/I-
    tags for its type. Clauses that do not parse are printed and skipped.
    """
    tags = ['O'] * len(tokens)
    for clause in text.lower().strip('.').split(','):

        match = re.match(r'^(.*) is an? (.*)$', clause)
        if match is None:
            print(clause)
            continue
        entity = match.group(1).strip()
        entity_type = match.group(2).strip()
        tag = ent_dict_rev.get(entity_type, 'O')
        entity_tokens = entity.split()

        span = len(entity_tokens)
        # Tag the first untagged occurrence of the entity span, then stop.
        for start in range(len(tokens) - span + 1):
            window = slice(start, start + span)
            if tokens[window] == entity_tokens and tags[window] == ['O'] * span:
                tags[window] = ['B-' + tag] + ['I-' + tag] * (span - 1)
                break

    return tags
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def map_output(feature):
    """Produce BIO tag sequences for one example's gold and generated text.

    Tokenisation is a lower-cased whitespace split of the input sentence,
    matching what `cvt_text_to_pred` expects.
    """
    tokens = feature['input'].lower().split()
    return {
        'label': cvt_text_to_pred(tokens, feature['output']),
        'pred': cvt_text_to_pred(tokens, feature['out_text']),
    }
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_ner(args, model, tokenizer):
    """Evaluate `model` on the FinGPT NER test split.

    Generates entity descriptions, converts both gold and generated text
    into BIO tag sequences, and prints a seqeval classification report.

    Args:
        args: namespace providing `batch_size` and `max_length`.
        model: causal LM with `.generate` and a `.device` attribute.
        tokenizer: matching tokenizer supporting batched padding.

    Returns:
        pandas.DataFrame with per-example BIO `label` and `pred` sequences.
    """
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fingpt-ner')['test']
    dataset = dataset.map(partial(test_mapping, args), load_from_cache_file=False)

    def collate_fn(batch):
        # Tokenize one batch of prompts with dynamic padding.
        inputs = tokenizer(
            [f["prompt"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    out_text_list = []
    # max(..., 1) avoids a modulo-by-zero when there are fewer than 5 batches
    # (bug in the original).
    log_interval = max(len(dataloader) // 5, 1)

    for idx, inputs in enumerate(tqdm(dataloader)):
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        res = model.generate(**inputs, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        if (idx + 1) % log_interval == 0:
            tqdm.write(f'{idx}: {res_sentences[0]}')
        # Everything after "Answer: " is the model's entity listing.
        out_text = [o.split("Answer: ")[1] for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset = dataset.add_column("out_text", out_text_list)
    dataset = dataset.map(map_output, load_from_cache_file=False)
    dataset = dataset.to_pandas()

    print(dataset)
    dataset.to_csv('tmp.csv')

    # Sequences come back from pandas as numpy arrays; seqeval wants lists.
    label = [d.tolist() for d in dataset['label']]
    pred = [d.tolist() for d in dataset['pred']]

    print(classification_report(label, pred, digits=4))

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/nwgi.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore")
|
| 3 |
+
|
| 4 |
+
from sklearn.metrics import accuracy_score,f1_score
|
| 5 |
+
from datasets import load_dataset, load_from_disk
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import datasets
|
| 8 |
+
import torch
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# Collapse NWGI's 7-way sentiment labels down to 3 classes: the "mildly"
# labels are folded into neutral, while "moderately"/"strong" keep polarity.
dic = {
    'strong negative':"negative",
    'moderately negative':"negative",
    'mildly negative':"neutral",
    'strong positive':"positive",
    'moderately positive':"positive",
    'mildly positive':'neutral',
    'neutral':'neutral',
}
|
| 20 |
+
|
| 21 |
+
def format_example(example: dict) -> dict:
    """Assemble the prompt ("context") and gold answer ("target") for one row.

    The "Input:" line is included only when the row's input is non-empty.
    """
    parts = [f"Instruction: {example['instruction']}\n"]
    if example.get("input"):
        parts.append(f"Input: {example['input']}\n")
    parts.append("Answer: ")
    return {"context": "".join(parts), "target": example["output"]}
|
| 28 |
+
|
| 29 |
+
def change_target(x):
    """Map a free-form answer onto {negative, neutral, positive}.

    'positive' takes precedence over 'negative' when both appear, matching
    the original if/elif ordering; anything else is neutral.
    """
    for word in ('positive', 'Positive'):
        if word in x:
            return 'positive'
    for word in ('negative', 'Negative'):
        if word in x:
            return 'negative'
    return 'neutral'
|
| 36 |
+
|
| 37 |
+
def test_nwgi(args, model, tokenizer, prompt_fun=None):
    """Evaluate sentiment classification on the NWGI dataset.

    Prints accuracy / F1 scores and returns the per-sample results as a
    pandas DataFrame.  `prompt_fun(row)`, if given, builds the instruction
    per row instead of the fixed default prompt.  Requires a CUDA model.
    """
    batch_size = args.batch_size
    # dataset = load_dataset('oliverwang15/news_with_gpt_instructions')
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/news_with_gpt_instructions/')
    # Fix: the saved dataset is a DatasetDict; the pandas-style column
    # assignments below fail on it.  Select the evaluation split (assumed
    # 'test' — TODO confirm against the saved dataset) and convert to pandas,
    # mirroring what tfns.py does.
    if hasattr(dataset, 'keys') and 'test' in dataset:
        dataset = dataset['test']
    dataset = dataset.to_pandas()
    dataset['output'] = dataset['label'].apply(lambda x: dic[x])

    if prompt_fun is None:
        dataset["instruction"] = "What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}."
        # dataset["instruction"] = "What is the sentiment of this news? Please choose an answer from {strong negative/moderately negative/mildly negative/neutral/mildly positive/moderately positive/strong positive}."
    else:
        dataset["instruction"] = dataset.apply(prompt_fun, axis=1)
    dataset["input"] = dataset["news"]

    dataset = dataset[['input', 'output', 'instruction']]
    dataset[["context", "target"]] = dataset.apply(format_example, axis=1, result_type="expand")

    # print example
    print(f"\n\nPrompt example:\n{dataset['context'][0]}\n\n")

    context = dataset['context'].tolist()

    # Fix: ceil division; the old `n // bs + 1` produced a trailing EMPTY
    # batch whenever the dataset size was an exact multiple of batch_size.
    total_steps = (len(context) + batch_size - 1) // batch_size
    print(f"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}")

    out_text_list = []
    for i in tqdm(range(total_steps)):
        tmp_context = context[i * batch_size:(i + 1) * batch_size]
        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=512, return_token_type_ids=False)
        tokens = {k: v.cuda() for k, v in tokens.items()}
        res = model.generate(**tokens, max_length=512, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        # Keep only the generated text after the "Answer: " marker.
        out_text = [o.split("Answer: ")[1] for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset["out_text"] = out_text_list
    dataset["new_target"] = dataset["target"].apply(change_target)
    dataset["new_out"] = dataset["out_text"].apply(change_target)

    acc = accuracy_score(dataset["new_target"], dataset["new_out"])
    f1_macro = f1_score(dataset["new_target"], dataset["new_out"], average="macro")
    f1_micro = f1_score(dataset["new_target"], dataset["new_out"], average="micro")
    f1_weighted = f1_score(dataset["new_target"], dataset["new_out"], average="weighted")

    print(f"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. ")

    return dataset
|
fingpt/FinGPT_Benchmark/benchmarks/sentiment_templates.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
What is the sentiment of the input {type} from financial perspective?
|
| 2 |
+
Assign a sentiment category to this {type} related to finance.
|
| 3 |
+
Categorize the input {type}'s emotional tone into one of three groups.
|
| 4 |
+
Determine the sentiment expressed in the {type} from financial perspective.
|
| 5 |
+
Characterize the {type}'s sentiment using the following options.
|
fingpt/FinGPT_Benchmark/benchmarks/tfns.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore")
|
| 3 |
+
|
| 4 |
+
from sklearn.metrics import accuracy_score,f1_score
|
| 5 |
+
from datasets import load_dataset, load_from_disk
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import datasets
|
| 8 |
+
import torch
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# Map the TFNS integer class ids to their text labels.
dic = {
    0:"negative",
    1:'positive',
    2:'neutral',
}
|
| 16 |
+
|
| 17 |
+
def format_example(example: dict) -> dict:
    """Build the prompt ("context") and expected answer ("target") for a row."""
    segments = [f"Instruction: {example['instruction']}\n"]
    user_input = example.get("input")
    if user_input:
        # Only non-empty inputs contribute an "Input:" line.
        segments.append(f"Input: {user_input}\n")
    segments.append("Answer: ")
    return {"context": "".join(segments), "target": example["output"]}
|
| 24 |
+
|
| 25 |
+
def change_target(x):
    """Normalize a generated answer to one of negative/neutral/positive.

    'positive' wins over 'negative' when both words occur, preserving the
    original if/elif precedence.
    """
    def mentions(word):
        # Matches lowercase and Capitalized spellings only, as before.
        return word in x or word.capitalize() in x

    if mentions('positive'):
        return 'positive'
    if mentions('negative'):
        return 'negative'
    return 'neutral'
|
| 32 |
+
|
| 33 |
+
def test_tfns(args, model, tokenizer, prompt_fun=None):
    """Evaluate sentiment classification on Twitter Financial News Sentiment.

    Prints accuracy / F1 scores and returns the per-sample results as a
    pandas DataFrame.  `prompt_fun(row)`, if given, builds the instruction
    per row instead of the fixed default prompt.  Requires a CUDA model.
    """
    batch_size = args.batch_size
    # dataset = load_dataset('zeroshot/twitter-financial-news-sentiment')
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/twitter-financial-news-sentiment')
    dataset = dataset['validation']
    dataset = dataset.to_pandas()
    # Replace integer class ids with text labels.
    dataset['label'] = dataset['label'].apply(lambda x: dic[x])

    if prompt_fun is None:
        dataset["instruction"] = 'What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.'
    else:
        dataset["instruction"] = dataset.apply(prompt_fun, axis=1)

    dataset.columns = ['input', 'output', 'instruction']
    dataset[["context", "target"]] = dataset.apply(format_example, axis=1, result_type="expand")

    # print example
    print(f"\n\nPrompt example:\n{dataset['context'][0]}\n\n")

    context = dataset['context'].tolist()

    # Fix: ceil division; the old `n // bs + 1` produced a trailing EMPTY
    # batch whenever the dataset size was an exact multiple of batch_size,
    # which would then be fed to the tokenizer and model.generate.
    total_steps = (len(context) + batch_size - 1) // batch_size
    print(f"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}")

    out_text_list = []
    for i in tqdm(range(total_steps)):
        tmp_context = context[i * batch_size:(i + 1) * batch_size]
        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=512, return_token_type_ids=False)
        tokens = {k: v.cuda() for k, v in tokens.items()}
        res = model.generate(**tokens, max_length=512, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        # Keep only the generated text after the "Answer: " marker.
        out_text = [o.split("Answer: ")[1] for o in res_sentences]
        out_text_list += out_text
        torch.cuda.empty_cache()

    dataset["out_text"] = out_text_list
    dataset["new_target"] = dataset["target"].apply(change_target)
    dataset["new_out"] = dataset["out_text"].apply(change_target)

    acc = accuracy_score(dataset["new_target"], dataset["new_out"])
    f1_macro = f1_score(dataset["new_target"], dataset["new_out"], average="macro")
    f1_micro = f1_score(dataset["new_target"], dataset["new_out"], average="micro")
    f1_weighted = f1_score(dataset["new_target"], dataset["new_out"], average="weighted")

    print(f"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. ")

    return dataset
|
fingpt/FinGPT_Benchmark/config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 3 |
+
"train_batch_size": "auto",
|
| 4 |
+
"gradient_accumulation_steps": "auto",
|
| 5 |
+
"optimizer": {
|
| 6 |
+
"type": "ZeroOneAdam",
|
| 7 |
+
"params": {
|
| 8 |
+
"lr": "auto",
|
| 9 |
+
"weight_decay": "auto",
|
| 10 |
+
"bias_correction": false,
|
| 11 |
+
"var_freeze_step": 1000,
|
| 12 |
+
"var_update_scaler": 16,
|
| 13 |
+
"local_step_scaler": 1000,
|
| 14 |
+
"local_step_clipper": 16,
|
| 15 |
+
"cuda_aware": true,
|
| 16 |
+
"comm_backend_name": "nccl"
|
| 17 |
+
}
|
| 18 |
+
},
|
| 19 |
+
"scheduler": {
|
| 20 |
+
"type": "WarmupLR",
|
| 21 |
+
"params": {
|
| 22 |
+
"warmup_min_lr": 0,
|
| 23 |
+
"warmup_max_lr": "auto",
|
| 24 |
+
"warmup_num_steps": "auto"
|
| 25 |
+
}
|
| 26 |
+
},
|
| 27 |
+
"fp16": {
|
| 28 |
+
"enabled": true
|
| 29 |
+
},
|
| 30 |
+
"zero_optimization": {
|
| 31 |
+
"stage": 0
|
| 32 |
+
}
|
| 33 |
+
}
|
fingpt/FinGPT_Benchmark/config_hf.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 3 |
+
"train_batch_size": "auto",
|
| 4 |
+
"gradient_accumulation_steps": "auto",
|
| 5 |
+
"fp16": {
|
| 6 |
+
"enabled": true
|
| 7 |
+
},
|
| 8 |
+
"zero_optimization": {
|
| 9 |
+
"stage": 0
|
| 10 |
+
}
|
| 11 |
+
}
|
fingpt/FinGPT_Benchmark/config_new.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 3 |
+
"train_batch_size": "auto",
|
| 4 |
+
"gradient_accumulation_steps": "auto",
|
| 5 |
+
"optimizer": {
|
| 6 |
+
"type": "AdamW",
|
| 7 |
+
"params": {
|
| 8 |
+
"lr": "auto",
|
| 9 |
+
"weight_decay": "auto",
|
| 10 |
+
"betas": "auto",
|
| 11 |
+
"eps": "auto"
|
| 12 |
+
}
|
| 13 |
+
},
|
| 14 |
+
"scheduler": {
|
| 15 |
+
"type": "WarmupDecayLR",
|
| 16 |
+
"params": {
|
| 17 |
+
"last_batch_iteration": -1,
|
| 18 |
+
"total_num_steps": "auto",
|
| 19 |
+
"warmup_min_lr": "auto",
|
| 20 |
+
"warmup_max_lr": "auto",
|
| 21 |
+
"warmup_num_steps": "auto"
|
| 22 |
+
}
|
| 23 |
+
},
|
| 24 |
+
"fp16": {
|
| 25 |
+
"enabled": true,
|
| 26 |
+
"loss_scale": 0,
|
| 27 |
+
"loss_scale_window": 1000,
|
| 28 |
+
"initial_scale_power": 16,
|
| 29 |
+
"hysteresis": 2,
|
| 30 |
+
"min_loss_scale": 1
|
| 31 |
+
},
|
| 32 |
+
"zero_optimization": {
|
| 33 |
+
"stage": 0
|
| 34 |
+
}
|
| 35 |
+
}
|
fingpt/FinGPT_Benchmark/data/__init__.py
ADDED
|
File without changes
|
fingpt/FinGPT_Benchmark/data/download.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
# (source_args, destination) pairs.  source_args are splatted into
# datasets.load_dataset(); destination is a directory name relative to
# this file.  Fix: the duplicate ('FinGPT/fingpt-finred', 'fingpt-finred')
# entry has been removed — it caused a redundant load/skip on every run.
DATASETS = [
    # source, destination
    (('pauri32/fiqa-2018', None), 'fiqa-2018'),
    (('FinGPT/fingpt-finred', None), 'fingpt-finred'),
    (('zeroshot/twitter-financial-news-sentiment', None), 'twitter-financial-news-sentiment'),
    (('oliverwang15/news_with_gpt_instructions', None), 'news_with_gpt_instructions'),
    (('financial_phrasebank', 'sentences_50agree'), 'financial_phrasebank-sentences_50agree'),
    (('FinGPT/fingpt-fiqa_qa', None), 'fingpt-fiqa_qa'),
    (('FinGPT/fingpt-headline-cls', None), 'fingpt-headline-cls'),
    (('FinGPT/fingpt-convfinqa', None), 'fingpt-convfinqa'),
    (('FinGPT/fingpt-finred-cls', None), 'fingpt-finred-cls'),
    (('FinGPT/fingpt-ner', None), 'fingpt-ner'),
    (('FinGPT/fingpt-headline', None), 'fingpt-headline-instruct'),
    (('FinGPT/fingpt-finred-re', None), 'fingpt-finred-re'),
    (('FinGPT/fingpt-ner-cls', None), 'fingpt-ner-cls'),
    (('FinGPT/fingpt-fineval', None), 'fingpt-fineval'),
    (('FinGPT/fingpt-sentiment-cls', None), 'fingpt-sentiment-cls'),
]
|
| 24 |
+
|
| 25 |
+
def download(no_cache: bool = False):
    """Downloads all datasets to where the FinGPT library is located.

    Destinations that already exist on disk are skipped unless `no_cache`
    is truthy.
    """
    data_dir = Path(__file__).parent

    for source_args, dest_name in DATASETS:
        target = data_dir / dest_name
        if target.is_dir() and not no_cache:
            print(f"Dataset found at {target}, skipping")
            continue
        # Fetch from the hub and persist next to this file.
        datasets.load_dataset(*source_args).save_to_disk(target)
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Fix: the flag was declared with `type=str`, so ANY supplied value --
    # including the literal string "False" -- evaluated truthy and forced a
    # re-download.  A boolean switch is what the code actually needs.
    parser.add_argument("--no_cache", action="store_true",
                        help="Redownload all datasets, even if already saved locally")

    args = parser.parse_args()
    download(no_cache=args.no_cache)
|
fingpt/FinGPT_Benchmark/data/prepare_data.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
fingpt/FinGPT_Benchmark/demo.ipynb
ADDED
|
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Read before you start:\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook gives a test demo for all the LLMs we trained during phase2: Multi-Task Instruction Tuning.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"- LLMs: Llama2, Falcon, BLOOM, ChatGLM2, Qwen, MPT\n",
|
| 12 |
+
"- Tasks: Sentiment Analysis, Headline Classification, Named Entity Extraction, Relation Extraction\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"All the models & instruction data samples used are also available in our huggingface organization. [https://huggingface.co/FinGPT]\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"Models trained in phase1&3 are not provided, as MT-models cover most of their capacity. Feel free to train your own models based on the tasks you want.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"Due to the limited diversity of the financial tasks and datasets we used, models might not response correctly to out-of-scope instructions. We'll delve into the generalization ability more in our future works."
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": 1,
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"outputs": [],
|
| 26 |
+
"source": [
|
| 27 |
+
"# First choose to load data/model from huggingface or local space\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"FROM_REMOTE = False"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": 2,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [
|
| 37 |
+
{
|
| 38 |
+
"name": "stdout",
|
| 39 |
+
"output_type": "stream",
|
| 40 |
+
"text": [
|
| 41 |
+
"[2023-10-15 20:44:54,084] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
|
| 42 |
+
]
|
| 43 |
+
}
|
| 44 |
+
],
|
| 45 |
+
"source": [
|
| 46 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 47 |
+
"from peft import PeftModel\n",
|
| 48 |
+
"from utils import *"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": 3,
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"outputs": [],
|
| 56 |
+
"source": [
|
| 57 |
+
"import logging\n",
|
| 58 |
+
"# Suppress Warnings during inference\n",
|
| 59 |
+
"logging.getLogger(\"transformers\").setLevel(logging.ERROR)"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": 4,
|
| 65 |
+
"metadata": {},
|
| 66 |
+
"outputs": [],
|
| 67 |
+
"source": [
|
| 68 |
+
"demo_tasks = [\n",
|
| 69 |
+
" 'Financial Sentiment Analysis',\n",
|
| 70 |
+
" 'Financial Relation Extraction',\n",
|
| 71 |
+
" 'Financial Headline Classification',\n",
|
| 72 |
+
" 'Financial Named Entity Recognition',\n",
|
| 73 |
+
"]\n",
|
| 74 |
+
"demo_inputs = [\n",
|
| 75 |
+
" \"Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\",\n",
|
| 76 |
+
" \"Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\",\n",
|
| 77 |
+
" 'april gold down 20 cents to settle at $1,116.10/oz',\n",
|
| 78 |
+
" 'Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").',\n",
|
| 79 |
+
"]\n",
|
| 80 |
+
"demo_instructions = [\n",
|
| 81 |
+
" 'What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.',\n",
|
| 82 |
+
" 'Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.',\n",
|
| 83 |
+
" 'Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.',\n",
|
| 84 |
+
" 'Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.',\n",
|
| 85 |
+
"]"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": 5,
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [],
|
| 93 |
+
"source": [
|
| 94 |
+
"def load_model(base_model, peft_model, from_remote=False):\n",
|
| 95 |
+
" \n",
|
| 96 |
+
" model_name = parse_model_name(base_model, from_remote)\n",
|
| 97 |
+
"\n",
|
| 98 |
+
" model = AutoModelForCausalLM.from_pretrained(\n",
|
| 99 |
+
" model_name, trust_remote_code=True, \n",
|
| 100 |
+
" device_map=\"auto\",\n",
|
| 101 |
+
" )\n",
|
| 102 |
+
" model.model_parallel = True\n",
|
| 103 |
+
"\n",
|
| 104 |
+
" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
| 105 |
+
" \n",
|
| 106 |
+
" tokenizer.padding_side = \"left\"\n",
|
| 107 |
+
" if base_model == 'qwen':\n",
|
| 108 |
+
" tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
|
| 109 |
+
" tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
|
| 110 |
+
" if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
|
| 111 |
+
" tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
|
| 112 |
+
" model.resize_token_embeddings(len(tokenizer))\n",
|
| 113 |
+
" \n",
|
| 114 |
+
" model = PeftModel.from_pretrained(model, peft_model)\n",
|
| 115 |
+
" model = model.eval()\n",
|
| 116 |
+
" return model, tokenizer\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"def test_demo(model, tokenizer):\n",
|
| 120 |
+
"\n",
|
| 121 |
+
" for task_name, input, instruction in zip(demo_tasks, demo_inputs, demo_instructions):\n",
|
| 122 |
+
" prompt = 'Instruction: {instruction}\\nInput: {input}\\nAnswer: '.format(\n",
|
| 123 |
+
" input=input, \n",
|
| 124 |
+
" instruction=instruction\n",
|
| 125 |
+
" )\n",
|
| 126 |
+
" inputs = tokenizer(\n",
|
| 127 |
+
" prompt, return_tensors='pt',\n",
|
| 128 |
+
" padding=True, max_length=512,\n",
|
| 129 |
+
" return_token_type_ids=False\n",
|
| 130 |
+
" )\n",
|
| 131 |
+
" inputs = {key: value.to(model.device) for key, value in inputs.items()}\n",
|
| 132 |
+
" res = model.generate(\n",
|
| 133 |
+
" **inputs, max_length=512, do_sample=False,\n",
|
| 134 |
+
" eos_token_id=tokenizer.eos_token_id\n",
|
| 135 |
+
" )\n",
|
| 136 |
+
" output = tokenizer.decode(res[0], skip_special_tokens=True)\n",
|
| 137 |
+
" print(f\"\\n==== {task_name} ====\\n\")\n",
|
| 138 |
+
" print(output)\n",
|
| 139 |
+
" "
|
| 140 |
+
]
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"cell_type": "markdown",
|
| 144 |
+
"metadata": {},
|
| 145 |
+
"source": [
|
| 146 |
+
"# Llama2-7B"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "code",
|
| 151 |
+
"execution_count": 6,
|
| 152 |
+
"metadata": {},
|
| 153 |
+
"outputs": [
|
| 154 |
+
{
|
| 155 |
+
"data": {
|
| 156 |
+
"application/json": {
|
| 157 |
+
"ascii": false,
|
| 158 |
+
"bar_format": null,
|
| 159 |
+
"colour": null,
|
| 160 |
+
"elapsed": 0.006228446960449219,
|
| 161 |
+
"initial": 0,
|
| 162 |
+
"n": 0,
|
| 163 |
+
"ncols": null,
|
| 164 |
+
"nrows": null,
|
| 165 |
+
"postfix": null,
|
| 166 |
+
"prefix": "Loading checkpoint shards",
|
| 167 |
+
"rate": null,
|
| 168 |
+
"total": 2,
|
| 169 |
+
"unit": "it",
|
| 170 |
+
"unit_divisor": 1000,
|
| 171 |
+
"unit_scale": false
|
| 172 |
+
},
|
| 173 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 174 |
+
"model_id": "0d58aff745fb486780792c86384fe702",
|
| 175 |
+
"version_major": 2,
|
| 176 |
+
"version_minor": 0
|
| 177 |
+
},
|
| 178 |
+
"text/plain": [
|
| 179 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 180 |
+
]
|
| 181 |
+
},
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"output_type": "display_data"
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"name": "stderr",
|
| 187 |
+
"output_type": "stream",
|
| 188 |
+
"text": [
|
| 189 |
+
"Using pad_token, but it is not set yet.\n",
|
| 190 |
+
"/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2436: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.\n",
|
| 191 |
+
" warnings.warn(\n",
|
| 192 |
+
"/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
|
| 193 |
+
" warnings.warn(\n",
|
| 194 |
+
"/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
|
| 195 |
+
" warnings.warn(\n"
|
| 196 |
+
]
|
| 197 |
+
},
|
| 198 |
+
{
|
| 199 |
+
"name": "stdout",
|
| 200 |
+
"output_type": "stream",
|
| 201 |
+
"text": [
|
| 202 |
+
"\n",
|
| 203 |
+
"==== Financial Sentiment Analysis ====\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
|
| 206 |
+
"Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
|
| 207 |
+
"Answer: positive\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"==== Financial Relation Extraction ====\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
|
| 212 |
+
"Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
|
| 213 |
+
"Answer: product_or_material_produced: PeopleSoft, software; parent_organization: Siebel, Oracle Corporation; industry: Oracle Corporation, software; product_or_material_produced: Oracle Corporation, software; product_or_material_produced: Oracle Corporation, software\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"==== Financial Headline Classification ====\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
|
| 218 |
+
"Input: april gold down 20 cents to settle at $1,116.10/oz\n",
|
| 219 |
+
"Answer: Yes\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"==== Financial Named Entity Recognition ====\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
|
| 224 |
+
"Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
|
| 225 |
+
"Answer: Bank is an organization, Borrower is a person.\n"
|
| 226 |
+
]
|
| 227 |
+
}
|
| 228 |
+
],
|
| 229 |
+
"source": [
|
| 230 |
+
"base_model = 'llama2'\n",
|
| 231 |
+
"peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-llama2-linear_202309241345'\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
|
| 234 |
+
"test_demo(model, tokenizer)"
|
| 235 |
+
]
|
| 236 |
+
},
|
| 237 |
+
{
|
| 238 |
+
"cell_type": "markdown",
|
| 239 |
+
"metadata": {},
|
| 240 |
+
"source": [
|
| 241 |
+
"# Qwen-7B"
|
| 242 |
+
]
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"cell_type": "code",
|
| 246 |
+
"execution_count": 7,
|
| 247 |
+
"metadata": {},
|
| 248 |
+
"outputs": [
|
| 249 |
+
{
|
| 250 |
+
"name": "stderr",
|
| 251 |
+
"output_type": "stream",
|
| 252 |
+
"text": [
|
| 253 |
+
"The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".\n",
|
| 254 |
+
"Try importing flash-attention for faster inference...\n",
|
| 255 |
+
"Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\n",
|
| 256 |
+
"Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n",
|
| 257 |
+
"Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention\n"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"data": {
|
| 262 |
+
"application/json": {
|
| 263 |
+
"ascii": false,
|
| 264 |
+
"bar_format": null,
|
| 265 |
+
"colour": null,
|
| 266 |
+
"elapsed": 0.004647493362426758,
|
| 267 |
+
"initial": 0,
|
| 268 |
+
"n": 0,
|
| 269 |
+
"ncols": null,
|
| 270 |
+
"nrows": null,
|
| 271 |
+
"postfix": null,
|
| 272 |
+
"prefix": "Loading checkpoint shards",
|
| 273 |
+
"rate": null,
|
| 274 |
+
"total": 8,
|
| 275 |
+
"unit": "it",
|
| 276 |
+
"unit_divisor": 1000,
|
| 277 |
+
"unit_scale": false
|
| 278 |
+
},
|
| 279 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 280 |
+
"model_id": "e1978e69ea784778acd1813cc0647c3e",
|
| 281 |
+
"version_major": 2,
|
| 282 |
+
"version_minor": 0
|
| 283 |
+
},
|
| 284 |
+
"text/plain": [
|
| 285 |
+
"Loading checkpoint shards: 0%| | 0/8 [00:00<?, ?it/s]"
|
| 286 |
+
]
|
| 287 |
+
},
|
| 288 |
+
"metadata": {},
|
| 289 |
+
"output_type": "display_data"
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"name": "stderr",
|
| 293 |
+
"output_type": "stream",
|
| 294 |
+
"text": [
|
| 295 |
+
"/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
|
| 296 |
+
" warnings.warn(\n",
|
| 297 |
+
"/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:377: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
|
| 298 |
+
" warnings.warn(\n"
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"name": "stdout",
|
| 303 |
+
"output_type": "stream",
|
| 304 |
+
"text": [
|
| 305 |
+
"\n",
|
| 306 |
+
"==== Financial Sentiment Analysis ====\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
|
| 309 |
+
"Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
|
| 310 |
+
"Answer: positive\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"==== Financial Relation Extraction ====\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
|
| 315 |
+
"Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
|
| 316 |
+
"Answer: subsidiary: PeopleSoft, JD Edwards\n",
|
| 317 |
+
"\n",
|
| 318 |
+
"==== Financial Headline Classification ====\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
|
| 321 |
+
"Input: april gold down 20 cents to settle at $1,116.10/oz\n",
|
| 322 |
+
"Answer: Yes\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"==== Financial Named Entity Recognition ====\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
|
| 327 |
+
"Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
|
| 328 |
+
"Answer: Bank is an organization, Borrower is a person.\n"
|
| 329 |
+
]
|
| 330 |
+
}
|
| 331 |
+
],
|
| 332 |
+
"source": [
|
| 333 |
+
"base_model = 'qwen'\n",
|
| 334 |
+
"peft_model = 'FinGPT/fingpt-mt_qwen-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-qwen-linear_202309221011'\n",
|
| 335 |
+
"\n",
|
| 336 |
+
"model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
|
| 337 |
+
"test_demo(model, tokenizer)"
|
| 338 |
+
]
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"cell_type": "markdown",
|
| 342 |
+
"metadata": {},
|
| 343 |
+
"source": [
|
| 344 |
+
"# Falcon-7B"
|
| 345 |
+
]
|
| 346 |
+
},
|
| 347 |
+
{
|
| 348 |
+
"cell_type": "code",
|
| 349 |
+
"execution_count": 8,
|
| 350 |
+
"metadata": {},
|
| 351 |
+
"outputs": [
|
| 352 |
+
{
|
| 353 |
+
"data": {
|
| 354 |
+
"application/json": {
|
| 355 |
+
"ascii": false,
|
| 356 |
+
"bar_format": null,
|
| 357 |
+
"colour": null,
|
| 358 |
+
"elapsed": 0.004422426223754883,
|
| 359 |
+
"initial": 0,
|
| 360 |
+
"n": 0,
|
| 361 |
+
"ncols": null,
|
| 362 |
+
"nrows": null,
|
| 363 |
+
"postfix": null,
|
| 364 |
+
"prefix": "Loading checkpoint shards",
|
| 365 |
+
"rate": null,
|
| 366 |
+
"total": 2,
|
| 367 |
+
"unit": "it",
|
| 368 |
+
"unit_divisor": 1000,
|
| 369 |
+
"unit_scale": false
|
| 370 |
+
},
|
| 371 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 372 |
+
"model_id": "e12fadfbaa6048538bbeef26ed563b28",
|
| 373 |
+
"version_major": 2,
|
| 374 |
+
"version_minor": 0
|
| 375 |
+
},
|
| 376 |
+
"text/plain": [
|
| 377 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 378 |
+
]
|
| 379 |
+
},
|
| 380 |
+
"metadata": {},
|
| 381 |
+
"output_type": "display_data"
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"name": "stderr",
|
| 385 |
+
"output_type": "stream",
|
| 386 |
+
"text": [
|
| 387 |
+
"Using pad_token, but it is not set yet.\n",
|
| 388 |
+
"/root/.conda/envs/torch2/lib/python3.9/site-packages/transformers/generation/utils.py:1411: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )\n",
|
| 389 |
+
" warnings.warn(\n"
|
| 390 |
+
]
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"name": "stdout",
|
| 394 |
+
"output_type": "stream",
|
| 395 |
+
"text": [
|
| 396 |
+
"\n",
|
| 397 |
+
"==== Financial Sentiment Analysis ====\n",
|
| 398 |
+
"\n",
|
| 399 |
+
"Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
|
| 400 |
+
"Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
|
| 401 |
+
"Answer: positive\n",
|
| 402 |
+
"\n",
|
| 403 |
+
"==== Financial Relation Extraction ====\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
|
| 406 |
+
"Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel, PeopleSoft, JD Edwards, E-Business Suite, Oracle Database, Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
|
| 407 |
+
"Answer: product_or_material_produced: PeopleSoft, Oracle Database\n",
|
| 408 |
+
"\n",
|
| 409 |
+
"==== Financial Headline Classification ====\n",
|
| 410 |
+
"\n",
|
| 411 |
+
"Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
|
| 412 |
+
"Input: april gold down 20 cents to settle at $1,116.10/oz\n",
|
| 413 |
+
"Answer: Yes\n",
|
| 414 |
+
"\n",
|
| 415 |
+
"==== Financial Named Entity Recognition ====\n",
|
| 416 |
+
"\n",
|
| 417 |
+
"Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
|
| 418 |
+
"Input: Subject to the terms and conditions of this Agreement, Bank agrees to lend to Borrower, from time to time prior to the Commitment Termination Date, equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
|
| 419 |
+
"Answer: Bank is an organization, Borrower is a person.\n"
|
| 420 |
+
]
|
| 421 |
+
}
|
| 422 |
+
],
|
| 423 |
+
"source": [
|
| 424 |
+
"base_model = 'falcon'\n",
|
| 425 |
+
"peft_model = 'FinGPT/fingpt-mt_falcon-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-falcon-linear_202309210126'\n",
|
| 426 |
+
"\n",
|
| 427 |
+
"model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
|
| 428 |
+
"test_demo(model, tokenizer)"
|
| 429 |
+
]
|
| 430 |
+
},
|
| 431 |
+
{
|
| 432 |
+
"cell_type": "markdown",
|
| 433 |
+
"metadata": {},
|
| 434 |
+
"source": [
|
| 435 |
+
"# ChatGLM2-6B"
|
| 436 |
+
]
|
| 437 |
+
},
|
| 438 |
+
{
|
| 439 |
+
"cell_type": "code",
|
| 440 |
+
"execution_count": 9,
|
| 441 |
+
"metadata": {},
|
| 442 |
+
"outputs": [
|
| 443 |
+
{
|
| 444 |
+
"data": {
|
| 445 |
+
"application/json": {
|
| 446 |
+
"ascii": false,
|
| 447 |
+
"bar_format": null,
|
| 448 |
+
"colour": null,
|
| 449 |
+
"elapsed": 0.004460573196411133,
|
| 450 |
+
"initial": 0,
|
| 451 |
+
"n": 0,
|
| 452 |
+
"ncols": null,
|
| 453 |
+
"nrows": null,
|
| 454 |
+
"postfix": null,
|
| 455 |
+
"prefix": "Loading checkpoint shards",
|
| 456 |
+
"rate": null,
|
| 457 |
+
"total": 7,
|
| 458 |
+
"unit": "it",
|
| 459 |
+
"unit_divisor": 1000,
|
| 460 |
+
"unit_scale": false
|
| 461 |
+
},
|
| 462 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 463 |
+
"model_id": "8bddd025a6514946b5f07f55e9c38f58",
|
| 464 |
+
"version_major": 2,
|
| 465 |
+
"version_minor": 0
|
| 466 |
+
},
|
| 467 |
+
"text/plain": [
|
| 468 |
+
"Loading checkpoint shards: 0%| | 0/7 [00:00<?, ?it/s]"
|
| 469 |
+
]
|
| 470 |
+
},
|
| 471 |
+
"metadata": {},
|
| 472 |
+
"output_type": "display_data"
|
| 473 |
+
},
|
| 474 |
+
{
|
| 475 |
+
"name": "stdout",
|
| 476 |
+
"output_type": "stream",
|
| 477 |
+
"text": [
|
| 478 |
+
"\n",
|
| 479 |
+
"==== Financial Sentiment Analysis ====\n",
|
| 480 |
+
"\n",
|
| 481 |
+
"Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
|
| 482 |
+
"Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
|
| 483 |
+
"Answer: positive\n",
|
| 484 |
+
"\n",
|
| 485 |
+
"==== Financial Relation Extraction ====\n",
|
| 486 |
+
"\n",
|
| 487 |
+
"Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
|
| 488 |
+
"Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
|
| 489 |
+
"Answer: product_or_material_produced: Oracle, Oracle Database; developer: Oracle, Oracle; product_or_material_produced: Oracle, Oracle Database\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"==== Financial Headline Classification ====\n",
|
| 492 |
+
"\n",
|
| 493 |
+
"Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
|
| 494 |
+
"Input: april gold down 20 cents to settle at $1,116.10/oz\n",
|
| 495 |
+
"Answer: Yes\n",
|
| 496 |
+
"\n",
|
| 497 |
+
"==== Financial Named Entity Recognition ====\n",
|
| 498 |
+
"\n",
|
| 499 |
+
"Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
|
| 500 |
+
"Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
|
| 501 |
+
"Answer: Bank is an organization, Borrower is a person.\n"
|
| 502 |
+
]
|
| 503 |
+
}
|
| 504 |
+
],
|
| 505 |
+
"source": [
|
| 506 |
+
"base_model = 'chatglm2'\n",
|
| 507 |
+
"peft_model = 'FinGPT/fingpt-mt_chatglm2-6b_lora' if FROM_REMOTE else 'finetuned_models/MT-chatglm2-linear_202309201120'\n",
|
| 508 |
+
"\n",
|
| 509 |
+
"model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
|
| 510 |
+
"test_demo(model, tokenizer)"
|
| 511 |
+
]
|
| 512 |
+
},
|
| 513 |
+
{
|
| 514 |
+
"cell_type": "markdown",
|
| 515 |
+
"metadata": {},
|
| 516 |
+
"source": [
|
| 517 |
+
"# BLOOM-7B1"
|
| 518 |
+
]
|
| 519 |
+
},
|
| 520 |
+
{
|
| 521 |
+
"cell_type": "code",
|
| 522 |
+
"execution_count": 10,
|
| 523 |
+
"metadata": {},
|
| 524 |
+
"outputs": [
|
| 525 |
+
{
|
| 526 |
+
"data": {
|
| 527 |
+
"application/json": {
|
| 528 |
+
"ascii": false,
|
| 529 |
+
"bar_format": null,
|
| 530 |
+
"colour": null,
|
| 531 |
+
"elapsed": 0.004486799240112305,
|
| 532 |
+
"initial": 0,
|
| 533 |
+
"n": 0,
|
| 534 |
+
"ncols": null,
|
| 535 |
+
"nrows": null,
|
| 536 |
+
"postfix": null,
|
| 537 |
+
"prefix": "Loading checkpoint shards",
|
| 538 |
+
"rate": null,
|
| 539 |
+
"total": 2,
|
| 540 |
+
"unit": "it",
|
| 541 |
+
"unit_divisor": 1000,
|
| 542 |
+
"unit_scale": false
|
| 543 |
+
},
|
| 544 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 545 |
+
"model_id": "32ee0b5e2df049a0b9e458c779e09a68",
|
| 546 |
+
"version_major": 2,
|
| 547 |
+
"version_minor": 0
|
| 548 |
+
},
|
| 549 |
+
"text/plain": [
|
| 550 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 551 |
+
]
|
| 552 |
+
},
|
| 553 |
+
"metadata": {},
|
| 554 |
+
"output_type": "display_data"
|
| 555 |
+
},
|
| 556 |
+
{
|
| 557 |
+
"name": "stdout",
|
| 558 |
+
"output_type": "stream",
|
| 559 |
+
"text": [
|
| 560 |
+
"\n",
|
| 561 |
+
"==== Financial Sentiment Analysis ====\n",
|
| 562 |
+
"\n",
|
| 563 |
+
"Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
|
| 564 |
+
"Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
|
| 565 |
+
"Answer: positive\n",
|
| 566 |
+
"\n",
|
| 567 |
+
"==== Financial Relation Extraction ====\n",
|
| 568 |
+
"\n",
|
| 569 |
+
"Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
|
| 570 |
+
"Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
|
| 571 |
+
"Answer: product_or_material_produced: software provider, Software\n",
|
| 572 |
+
"\n",
|
| 573 |
+
"==== Financial Headline Classification ====\n",
|
| 574 |
+
"\n",
|
| 575 |
+
"Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
|
| 576 |
+
"Input: april gold down 20 cents to settle at $1,116.10/oz\n",
|
| 577 |
+
"Answer: Yes\n",
|
| 578 |
+
"\n",
|
| 579 |
+
"==== Financial Named Entity Recognition ====\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
|
| 582 |
+
"Input: Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
|
| 583 |
+
"Answer: Bank is an organization, Borrower is a person.\n"
|
| 584 |
+
]
|
| 585 |
+
}
|
| 586 |
+
],
|
| 587 |
+
"source": [
|
| 588 |
+
"base_model = 'bloom'\n",
|
| 589 |
+
"peft_model = 'FinGPT/fingpt-mt_bloom-7b1_lora' if FROM_REMOTE else 'finetuned_models/MT-bloom-linear_202309211510'\n",
|
| 590 |
+
"\n",
|
| 591 |
+
"model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
|
| 592 |
+
"test_demo(model, tokenizer)"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
{
|
| 596 |
+
"cell_type": "markdown",
|
| 597 |
+
"metadata": {},
|
| 598 |
+
"source": [
|
| 599 |
+
"# MPT-7B"
|
| 600 |
+
]
|
| 601 |
+
},
|
| 602 |
+
{
|
| 603 |
+
"cell_type": "code",
|
| 604 |
+
"execution_count": 11,
|
| 605 |
+
"metadata": {},
|
| 606 |
+
"outputs": [
|
| 607 |
+
{
|
| 608 |
+
"name": "stderr",
|
| 609 |
+
"output_type": "stream",
|
| 610 |
+
"text": [
|
| 611 |
+
"/root/.cache/huggingface/modules/transformers_modules/mpt-7b-peft-compatible/attention.py:148: UserWarning: Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` otherwise we recommend using `attn_impl: triton`.\n",
|
| 612 |
+
" warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')\n",
|
| 613 |
+
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
|
| 614 |
+
]
|
| 615 |
+
},
|
| 616 |
+
{
|
| 617 |
+
"data": {
|
| 618 |
+
"application/json": {
|
| 619 |
+
"ascii": false,
|
| 620 |
+
"bar_format": null,
|
| 621 |
+
"colour": null,
|
| 622 |
+
"elapsed": 0.004449605941772461,
|
| 623 |
+
"initial": 0,
|
| 624 |
+
"n": 0,
|
| 625 |
+
"ncols": null,
|
| 626 |
+
"nrows": null,
|
| 627 |
+
"postfix": null,
|
| 628 |
+
"prefix": "Loading checkpoint shards",
|
| 629 |
+
"rate": null,
|
| 630 |
+
"total": 2,
|
| 631 |
+
"unit": "it",
|
| 632 |
+
"unit_divisor": 1000,
|
| 633 |
+
"unit_scale": false
|
| 634 |
+
},
|
| 635 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 636 |
+
"model_id": "0440bc96112344c493c8a1f5dd76f319",
|
| 637 |
+
"version_major": 2,
|
| 638 |
+
"version_minor": 0
|
| 639 |
+
},
|
| 640 |
+
"text/plain": [
|
| 641 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
| 642 |
+
]
|
| 643 |
+
},
|
| 644 |
+
"metadata": {},
|
| 645 |
+
"output_type": "display_data"
|
| 646 |
+
},
|
| 647 |
+
{
|
| 648 |
+
"name": "stderr",
|
| 649 |
+
"output_type": "stream",
|
| 650 |
+
"text": [
|
| 651 |
+
"Using pad_token, but it is not set yet.\n"
|
| 652 |
+
]
|
| 653 |
+
},
|
| 654 |
+
{
|
| 655 |
+
"name": "stdout",
|
| 656 |
+
"output_type": "stream",
|
| 657 |
+
"text": [
|
| 658 |
+
"\n",
|
| 659 |
+
"==== Financial Sentiment Analysis ====\n",
|
| 660 |
+
"\n",
|
| 661 |
+
"Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
|
| 662 |
+
"Input: Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\n",
|
| 663 |
+
"Answer: positive\n",
|
| 664 |
+
"\n",
|
| 665 |
+
"==== Financial Relation Extraction ====\n",
|
| 666 |
+
"\n",
|
| 667 |
+
"Instruction: Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.\n",
|
| 668 |
+
"Input: Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel, PeopleSoft, JD Edwards, E-Business Suite, Oracle Database, Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\n",
|
| 669 |
+
"Answer: product_or_material_produced: Hyperion, software\n",
|
| 670 |
+
"\n",
|
| 671 |
+
"==== Financial Headline Classification ====\n",
|
| 672 |
+
"\n",
|
| 673 |
+
"Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
|
| 674 |
+
"Input: april gold down 20 cents to settle at $1,116.10/oz\n",
|
| 675 |
+
"Answer: Yes\n",
|
| 676 |
+
"\n",
|
| 677 |
+
"==== Financial Named Entity Recognition ====\n",
|
| 678 |
+
"\n",
|
| 679 |
+
"Instruction: Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.\n",
|
| 680 |
+
"Input: Subject to the terms and conditions of this Agreement, Bank agrees to lend to Borrower, from time to time prior to the Commitment Termination Date, equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").\n",
|
| 681 |
+
"Answer: Bank is an organization, Borrower is a person.\n"
|
| 682 |
+
]
|
| 683 |
+
}
|
| 684 |
+
],
|
| 685 |
+
"source": [
|
| 686 |
+
"base_model = 'mpt'\n",
|
| 687 |
+
"peft_model = 'FinGPT/fingpt-mt_mpt-7b_lora' if FROM_REMOTE else 'finetuned_models/MT-mpt-linear_202309230221'\n",
|
| 688 |
+
"\n",
|
| 689 |
+
"model, tokenizer = load_model(base_model, peft_model, FROM_REMOTE)\n",
|
| 690 |
+
"test_demo(model, tokenizer)"
|
| 691 |
+
]
|
| 692 |
+
}
|
| 693 |
+
],
|
| 694 |
+
"metadata": {
|
| 695 |
+
"kernelspec": {
|
| 696 |
+
"display_name": "torch2",
|
| 697 |
+
"language": "python",
|
| 698 |
+
"name": "torch2"
|
| 699 |
+
},
|
| 700 |
+
"language_info": {
|
| 701 |
+
"codemirror_mode": {
|
| 702 |
+
"name": "ipython",
|
| 703 |
+
"version": 3
|
| 704 |
+
},
|
| 705 |
+
"file_extension": ".py",
|
| 706 |
+
"mimetype": "text/x-python",
|
| 707 |
+
"name": "python",
|
| 708 |
+
"nbconvert_exporter": "python",
|
| 709 |
+
"pygments_lexer": "ipython3",
|
| 710 |
+
"version": "3.9.12"
|
| 711 |
+
}
|
| 712 |
+
},
|
| 713 |
+
"nbformat": 4,
|
| 714 |
+
"nbformat_minor": 4
|
| 715 |
+
}
|
fingpt/FinGPT_Benchmark/readme.md
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FinGPT's Benchmark
|
| 2 |
+
|
| 3 |
+
[FinGPT: Instruction Tuning Benchmark for Open-Source Large Language Models in Financial Datasets
|
| 4 |
+
](https://arxiv.org/abs/2310.04793)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
The datasets we used, and the multi-task financial LLMs models are available at <https://huggingface.co/FinGPT>
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
Before you start, make sure you have the correct versions of the key packages installed.
|
| 12 |
+
```
|
| 13 |
+
transformers==4.32.0
|
| 14 |
+
peft==0.5.0
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
[Weights & Biases](https://wandb.ai/site) is a good tool for tracking model training and inference, you need to register, get a free API key, and create a new project.
|
| 18 |
+
|
| 19 |
+
wandb produces some nice charts like the following:
|
| 20 |
+
|
| 21 |
+
<img width="440" alt="image" src="https://github.com/AI4Finance-Foundation/FinGPT/assets/31713746/04a08b3d-58e3-47aa-8b07-3ec6ff9dfea4">
|
| 22 |
+
<img width="440" alt="image" src="https://github.com/AI4Finance-Foundation/FinGPT/assets/31713746/f207a64b-622d-4a41-8e0f-1959a2d25450">
|
| 23 |
+
<img width="440" alt="image" src="https://github.com/AI4Finance-Foundation/FinGPT/assets/31713746/e7699c64-7c3c-4130-94b3-59688631120a">
|
| 24 |
+
<img width="440" alt="image" src="https://github.com/AI4Finance-Foundation/FinGPT/assets/31713746/65ca7853-3d33-4856-80e5-f03476efcc78">
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## Ready-to-use Demo
|
| 28 |
+
|
| 29 |
+
For users who want ready-to-use financial multi-task language models, please refer to `demo.ipynb`.
|
| 30 |
+
Following this notebook, you're able to test Llama2-7B, ChatGLM2-6B, MPT-7B, BLOOM-7B, Falcon-7B, or Qwen-7B with any of the following tasks:
|
| 31 |
+
- Financial Sentiment Analysis
|
| 32 |
+
- Headline Classification
|
| 33 |
+
- Named Entity Recognition
|
| 34 |
+
- Financial Relation Extraction
|
| 35 |
+
|
| 36 |
+
We suggest users follow the instruction template and task prompts that we used in our training process. Demos are shown in `demo.ipynb`. Due to the limited diversity of the financial tasks and datasets we used, models might not respond correctly to out-of-scope instructions. We'll delve into the generalization ability more in our future works.
|
| 37 |
+
|
| 38 |
+
## Prepare Data & Base Models
|
| 39 |
+
|
| 40 |
+
For the base models we used, we recommend pre-downloading them and save to `base_models/`.
|
| 41 |
+
|
| 42 |
+
Refer to the `parse_model_name()` function in `utils.py` for the huggingface models we used for each LLM. (We use base models rather than any instruction-tuned version or chat version, except for ChatGLM2)
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
For the datasets we used, download our processed instruction tuning data from huggingface. Take FinRED dataset as an example:
|
| 47 |
+
```
|
| 48 |
+
import datasets
|
| 49 |
+
|
| 50 |
+
dataset = datasets.load_dataset('FinGPT/fingpt-finred')
|
| 51 |
+
# save to local disk space (recommended)
|
| 52 |
+
dataset.save_to_disk('data/fingpt-finred')
|
| 53 |
+
```
|
| 54 |
+
Then `finred` becomes an available task option for training.
|
| 55 |
+
|
| 56 |
+
We use different datasets at different phases of our instruction tuning paradigm.
|
| 57 |
+
- Task-specific Instruction Tuning: `sentiment-train / finred-re / ner / headline`
|
| 58 |
+
- Multi-task Instruction Tuning: `sentiment-train & finred & ner & headline`
|
| 59 |
+
- Zero-shot Aimed Instruction Tuning: `finred-cls & ner-cls & headline-cls -> sentiment-cls (test)`
|
| 60 |
+
|
| 61 |
+
You may download the datasets according to your needs. We also provide processed datasets for ConvFinQA and FinEval, but they are not used in our final work.
|
| 62 |
+
|
| 63 |
+
### prepare data from scratch
|
| 64 |
+
To prepare training data from raw data, you should follow `data/prepare_data.ipynb`.
|
| 65 |
+
|
| 66 |
+
We don't include any source data from other open-source financial datasets in our repository. So if you want to do it from scratch, you need to find the corresponding source data and put them in `data/` before you start.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## Instruction Tuning
|
| 71 |
+
|
| 72 |
+
`train.sh` contains examples of instruction tuning with this repo.
|
| 73 |
+
If you don't have training data & base models in your local disk, pass `--from_remote true` in addition.
|
| 74 |
+
|
| 75 |
+
### Task-specific Instruction Tuning
|
| 76 |
+
```
|
| 77 |
+
#chatglm2
|
| 78 |
+
deepspeed train_lora.py \
|
| 79 |
+
--run_name headline-chatglm2-linear \
|
| 80 |
+
--base_model chatglm2 \
|
| 81 |
+
--dataset headline \
|
| 82 |
+
--max_length 512 \
|
| 83 |
+
--batch_size 4 \
|
| 84 |
+
--learning_rate 1e-4 \
|
| 85 |
+
--num_epochs 8
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
Please be aware that `-i "localhost:2"` restricts DeepSpeed to a particular GPU on the local machine (here, the device with index 2).
|
| 89 |
+
|
| 90 |
+
```
|
| 91 |
+
#llama2-13b
|
| 92 |
+
deepspeed -i "localhost:2" train_lora.py \
|
| 93 |
+
--run_name sentiment-llama2-13b-8epoch-16batch \
|
| 94 |
+
--base_model llama2-13b-nr \
|
| 95 |
+
--dataset sentiment-train \
|
| 96 |
+
--max_length 512 \
|
| 97 |
+
--batch_size 16 \
|
| 98 |
+
--learning_rate 1e-5 \
|
| 99 |
+
--num_epochs 8 \
|
| 100 |
+
--from_remote True \
|
| 101 |
+
>train.log 2>&1 &
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
use
|
| 105 |
+
```
|
| 106 |
+
tail -f train.log
|
| 107 |
+
```
|
| 108 |
+
to check the training log
|
| 109 |
+
|
| 110 |
+
### Multi-task Instruction Tuning
|
| 111 |
+
```
|
| 112 |
+
deepspeed train_lora.py \
|
| 113 |
+
--run_name MT-falcon-linear \
|
| 114 |
+
--base_model falcon \
|
| 115 |
+
--dataset sentiment-train,headline,finred*3,ner*15 \
|
| 116 |
+
--max_length 512 \
|
| 117 |
+
--batch_size 4 \
|
| 118 |
+
--learning_rate 1e-4 \
|
| 119 |
+
--num_epochs 4
|
| 120 |
+
```
|
| 121 |
+
### Zero-shot Aimed Instruction Tuning
|
| 122 |
+
```
|
| 123 |
+
deepspeed train_lora.py \
|
| 124 |
+
--run_name GRCLS-sentiment-falcon-linear-small \
|
| 125 |
+
--base_model falcon \
|
| 126 |
+
--test_dataset sentiment-cls-instruct \
|
| 127 |
+
--dataset headline-cls-instruct,finred-cls-instruct*2,ner-cls-instruct*7 \
|
| 128 |
+
--max_length 512 \
|
| 129 |
+
--batch_size 4 \
|
| 130 |
+
--learning_rate 1e-4 \
|
| 131 |
+
--num_epochs 1 \
|
| 132 |
+
--log_interval 10 \
|
| 133 |
+
--warmup_ratio 0 \
|
| 134 |
+
--scheduler linear \
|
| 135 |
+
--evaluation_strategy steps \
|
| 136 |
+
--eval_steps 100 \
|
| 137 |
+
--ds_config config_hf.json
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## Evaluation for Financial Tasks
|
| 143 |
+
|
| 144 |
+
Refer to `benchmarks/evaluate.sh` for the evaluation script on all Financial Tasks.
|
| 145 |
+
You can evaluate your trained model on multiple tasks together. For example:
|
| 146 |
+
```
|
| 147 |
+
python benchmarks.py \
|
| 148 |
+
--dataset fpb,fiqa,tfns,nwgi,headline,ner,re \
|
| 149 |
+
--base_model llama2 \
|
| 150 |
+
--peft_model ../finetuned_models/MT-llama2-linear_202309241345 \
|
| 151 |
+
--batch_size 8 \
|
| 152 |
+
--max_length 512
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
```
|
| 156 |
+
#llama2-13b sentiment analysis
|
| 157 |
+
CUDA_VISIBLE_DEVICES=1 python benchmarks.py \
|
| 158 |
+
--dataset fpb,fiqa,tfns,nwgi \
|
| 159 |
+
--base_model llama2-13b-nr \
|
| 160 |
+
--peft_model ../finetuned_models/sentiment-llama2-13b-8epoch-16batch_202310271908 \
|
| 161 |
+
--batch_size 8 \
|
| 162 |
+
--max_length 512 \
|
| 163 |
+
--from_remote True
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
For Zero-shot Evaluation on Sentiment Analysis, we use multiple prompts and evaluate each of them.
|
| 167 |
+
The task indicators are `fiqa_mlt` and `fpb_mlt`.
|
| 168 |
+
|
| 169 |
+
|
fingpt/FinGPT_Benchmark/train.sh
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment setup for all training runs below.
# Expose the first four GPUs to this process; adjust to match your host.
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Work around NCCL initialization failures on hosts where GPU P2P is disabled.
export NCCL_IGNORE_DISABLED_P2P=1
# Silence transformers' repetitive advisory warnings in the training logs.
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
# Disable HF tokenizers parallelism to avoid fork-related deadlock warnings
# when the DataLoader forks worker processes.
export TOKENIZERS_PARALLELISM=0
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#---- Generalization ----
|
| 9 |
+
|
| 10 |
+
# deepspeed train_lora.py \
|
| 11 |
+
# --run_name GRCLS-sentiment-chatglm2-linear-1e-4lr \
|
| 12 |
+
# --base_model chatglm2 \
|
| 13 |
+
# --dataset headline-cls-instruct,finred-cls-instruct*2,ner-cls-instruct*7 \
|
| 14 |
+
# --test_dataset sentiment-cls-instruct \
|
| 15 |
+
# --max_length 512 \
|
| 16 |
+
# --batch_size 4 \
|
| 17 |
+
# --learning_rate 1e-4 \
|
| 18 |
+
# --num_epochs 1 \
|
| 19 |
+
# --log_interval 10 \
|
| 20 |
+
# --warmup_ratio 0.03 \
|
| 21 |
+
# --scheduler linear \
|
| 22 |
+
# --evaluation_strategy steps \
|
| 23 |
+
# --ds_config config_hf.json
|
| 24 |
+
|
| 25 |
+
# deepspeed train_lora.py \
|
| 26 |
+
# --run_name GRCLS-sentiment-llama2-linear-small \
|
| 27 |
+
# --base_model llama2 \
|
| 28 |
+
# --test_dataset sentiment-cls-instruct \
|
| 29 |
+
# --dataset headline-cls-instruct,finred-cls-instruct*2,ner-cls-instruct*7 \
|
| 30 |
+
# --max_length 512 \
|
| 31 |
+
# --batch_size 4 \
|
| 32 |
+
# --learning_rate 1e-4 \
|
| 33 |
+
# --num_epochs 1 \
|
| 34 |
+
# --log_interval 10 \
|
| 35 |
+
# --warmup_ratio 0 \
|
| 36 |
+
# --scheduler linear \
|
| 37 |
+
# --evaluation_strategy steps \
|
| 38 |
+
# --eval_steps 100 \
|
| 39 |
+
# --ds_config config_hf.json
|
| 40 |
+
|
| 41 |
+
# deepspeed train_lora.py \
|
| 42 |
+
# --run_name GRCLS-sentiment-falcon-linear-small \
|
| 43 |
+
# --base_model falcon \
|
| 44 |
+
# --test_dataset sentiment-cls-instruct \
|
| 45 |
+
# --dataset headline-cls-instruct,finred-cls-instruct*2,ner-cls-instruct*7 \
|
| 46 |
+
# --max_length 512 \
|
| 47 |
+
# --batch_size 4 \
|
| 48 |
+
# --learning_rate 1e-4 \
|
| 49 |
+
# --num_epochs 1 \
|
| 50 |
+
# --log_interval 10 \
|
| 51 |
+
# --warmup_ratio 0 \
|
| 52 |
+
# --scheduler linear \
|
| 53 |
+
# --evaluation_strategy steps \
|
| 54 |
+
# --eval_steps 100 \
|
| 55 |
+
# --ds_config config_hf.json
|
| 56 |
+
|
| 57 |
+
# deepspeed train_lora.py \
|
| 58 |
+
# --run_name GRCLS-sentiment-qwen-linear-small \
|
| 59 |
+
# --base_model qwen \
|
| 60 |
+
# --test_dataset sentiment-cls-instruct \
|
| 61 |
+
# --dataset headline-cls-instruct,finred-cls-instruct*2,ner-cls-instruct*7 \
|
| 62 |
+
# --max_length 512 \
|
| 63 |
+
# --batch_size 4 \
|
| 64 |
+
# --learning_rate 1e-4 \
|
| 65 |
+
# --num_epochs 1 \
|
| 66 |
+
# --log_interval 10 \
|
| 67 |
+
# --warmup_ratio 0 \
|
| 68 |
+
# --scheduler linear \
|
| 69 |
+
# --evaluation_strategy steps \
|
| 70 |
+
# --eval_steps 100 \
|
| 71 |
+
# --ds_config config_hf.json
|
| 72 |
+
|
| 73 |
+
# deepspeed train_lora.py \
|
| 74 |
+
# --run_name GRCLS-sentiment-bloom-linear-small \
|
| 75 |
+
# --base_model bloom \
|
| 76 |
+
# --test_dataset sentiment-cls-instruct \
|
| 77 |
+
# --dataset headline-cls-instruct,finred-cls-instruct*2,ner-cls-instruct*7 \
|
| 78 |
+
# --max_length 512 \
|
| 79 |
+
# --batch_size 4 \
|
| 80 |
+
# --learning_rate 1e-4 \
|
| 81 |
+
# --num_epochs 1 \
|
| 82 |
+
# --log_interval 10 \
|
| 83 |
+
# --warmup_ratio 0 \
|
| 84 |
+
# --scheduler linear \
|
| 85 |
+
# --evaluation_strategy steps \
|
| 86 |
+
# --eval_steps 100 \
|
| 87 |
+
# --ds_config config_hf.json
|
| 88 |
+
|
| 89 |
+
# deepspeed train_lora.py \
|
| 90 |
+
# --run_name GRCLS-sentiment-mpt-linear-small \
|
| 91 |
+
# --base_model mpt \
|
| 92 |
+
# --dataset headline-cls-instruct,finred-cls-instruct*2,ner-cls-instruct*7 \
|
| 93 |
+
# --test_dataset sentiment-cls-instruct \
|
| 94 |
+
# --max_length 512 \
|
| 95 |
+
# --batch_size 4 \
|
| 96 |
+
# --learning_rate 1e-4 \
|
| 97 |
+
# --num_epochs 1 \
|
| 98 |
+
# --log_interval 10 \
|
| 99 |
+
# --warmup_ratio 0.03 \
|
| 100 |
+
# --scheduler linear \
|
| 101 |
+
# --evaluation_strategy steps \
|
| 102 |
+
# --eval_steps 100 \
|
| 103 |
+
# --ds_config config_hf.json
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
#---- Multi-Task ----
|
| 107 |
+
|
| 108 |
+
# deepspeed train_lora.py \
|
| 109 |
+
# --run_name MT-chatglm2-linear \
|
| 110 |
+
# --base_model chatglm2 \
|
| 111 |
+
# --dataset sentiment-train,headline,finred*3,ner*15 \
|
| 112 |
+
# --max_length 512 \
|
| 113 |
+
# --batch_size 4 \
|
| 114 |
+
# --learning_rate 1e-4 \
|
| 115 |
+
# --num_epochs 4
|
| 116 |
+
|
| 117 |
+
# deepspeed train_lora.py \
|
| 118 |
+
# --run_name MT-falcon-linear \
|
| 119 |
+
# --base_model falcon \
|
| 120 |
+
# --dataset sentiment-train,headline,finred*3,ner*15 \
|
| 121 |
+
# --max_length 512 \
|
| 122 |
+
# --batch_size 4 \
|
| 123 |
+
# --learning_rate 1e-4 \
|
| 124 |
+
# --num_epochs 4
|
| 125 |
+
|
| 126 |
+
# deepspeed train_lora.py \
|
| 127 |
+
# --run_name MT-qwen-linear \
|
| 128 |
+
# --base_model qwen \
|
| 129 |
+
# --dataset sentiment-train,headline,finred*3,ner*15 \
|
| 130 |
+
# --max_length 512 \
|
| 131 |
+
# --batch_size 4 \
|
| 132 |
+
# --learning_rate 1e-4 \
|
| 133 |
+
# --num_epochs 4
|
| 134 |
+
|
| 135 |
+
# deepspeed train_lora.py \
|
| 136 |
+
# --run_name MT-mpt-linear \
|
| 137 |
+
# --base_model mpt \
|
| 138 |
+
# --dataset sentiment-train,headline,finred*3,ner*15 \
|
| 139 |
+
# --max_length 512 \
|
| 140 |
+
# --batch_size 4 \
|
| 141 |
+
# --learning_rate 1e-4 \
|
| 142 |
+
# --num_epochs 4
|
| 143 |
+
|
| 144 |
+
# deepspeed train_lora.py \
|
| 145 |
+
# --run_name MT-bloom-linear \
|
| 146 |
+
# --base_model bloom \
|
| 147 |
+
# --dataset sentiment-train,headline,finred*3,ner*15 \
|
| 148 |
+
# --max_length 512 \
|
| 149 |
+
# --batch_size 4 \
|
| 150 |
+
# --learning_rate 1e-4 \
|
| 151 |
+
# --num_epochs 4
|
| 152 |
+
|
| 153 |
+
# deepspeed train_lora.py \
|
| 154 |
+
# --run_name MT-llama2-linear \
|
| 155 |
+
# --base_model llama2 \
|
| 156 |
+
# --dataset sentiment-train,headline,finred*3,ner*15 \
|
| 157 |
+
# --max_length 512 \
|
| 158 |
+
# --batch_size 4 \
|
| 159 |
+
# --learning_rate 1e-4 \
|
| 160 |
+
# --num_epochs 4 \
|
| 161 |
+
# --log_interval 10
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
#---- FinEval ----
|
| 165 |
+
|
| 166 |
+
# deepspeed train_lora.py \
|
| 167 |
+
# --run_name fineval-internlm-linear \
|
| 168 |
+
# --base_model internlm \
|
| 169 |
+
# --dataset data/fingpt-fineval \
|
| 170 |
+
# --max_length 512 \
|
| 171 |
+
# --batch_size 4 \
|
| 172 |
+
# --learning_rate 1e-4 \
|
| 173 |
+
# --num_epochs 50 \
|
| 174 |
+
# --log_interval 10
|
| 175 |
+
|
| 176 |
+
# deepspeed train_lora.py \
|
| 177 |
+
# --run_name fineval-llama2-linear \
|
| 178 |
+
# --base_model llama2 \
|
| 179 |
+
# --dataset data/fingpt-fineval \
|
| 180 |
+
# --max_length 512 \
|
| 181 |
+
# --batch_size 4 \
|
| 182 |
+
# --learning_rate 1e-4 \
|
| 183 |
+
# --num_epochs 50 \
|
| 184 |
+
# --log_interval 10
|
| 185 |
+
|
| 186 |
+
# deepspeed train_lora.py \
|
| 187 |
+
# --run_name fineval-chatglm2-linear \
|
| 188 |
+
# --base_model chatglm2 \
|
| 189 |
+
# --dataset data/fingpt-fineval \
|
| 190 |
+
# --max_length 512 \
|
| 191 |
+
# --batch_size 4 \
|
| 192 |
+
# --learning_rate 1e-4 \
|
| 193 |
+
# --num_epochs 50 \
|
| 194 |
+
# --log_interval 10
|
| 195 |
+
|
| 196 |
+
# deepspeed train_lora.py \
|
| 197 |
+
# --run_name fineval-falcon-linear \
|
| 198 |
+
# --base_model falcon \
|
| 199 |
+
# --dataset data/fingpt-fineval \
|
| 200 |
+
# --max_length 512 \
|
| 201 |
+
# --batch_size 4 \
|
| 202 |
+
# --learning_rate 1e-4 \
|
| 203 |
+
# --num_epochs 50 \
|
| 204 |
+
# --log_interval 10
|
| 205 |
+
|
| 206 |
+
# deepspeed train_lora.py \
|
| 207 |
+
# --run_name fineval-qwen-linear \
|
| 208 |
+
# --base_model qwen \
|
| 209 |
+
# --dataset data/fingpt-fineval \
|
| 210 |
+
# --max_length 512 \
|
| 211 |
+
# --batch_size 4 \
|
| 212 |
+
# --learning_rate 1e-4 \
|
| 213 |
+
# --num_epochs 50 \
|
| 214 |
+
# --log_interval 10
|
| 215 |
+
|
| 216 |
+
# deepspeed train_lora.py \
|
| 217 |
+
# --run_name fineval-mpt-linear \
|
| 218 |
+
# --base_model mpt \
|
| 219 |
+
# --dataset data/fingpt-fineval \
|
| 220 |
+
# --max_length 512 \
|
| 221 |
+
# --batch_size 4 \
|
| 222 |
+
# --learning_rate 1e-4 \
|
| 223 |
+
# --num_epochs 50 \
|
| 224 |
+
# --log_interval 10
|
| 225 |
+
|
| 226 |
+
# deepspeed train_lora.py \
|
| 227 |
+
# --run_name fineval-bloom-linear \
|
| 228 |
+
# --base_model bloom \
|
| 229 |
+
# --dataset data/fingpt-fineval \
|
| 230 |
+
# --max_length 512 \
|
| 231 |
+
# --batch_size 4 \
|
| 232 |
+
# --learning_rate 1e-4 \
|
| 233 |
+
# --num_epochs 50 \
|
| 234 |
+
# --log_interval 10
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
#---- ConvFinQA ----
|
| 238 |
+
|
| 239 |
+
# deepspeed train_lora.py \
|
| 240 |
+
# --run_name convfinqa-llama2-linear \
|
| 241 |
+
# --base_model llama2 \
|
| 242 |
+
# --ds_config config_hf.json \
|
| 243 |
+
# --dataset data/fingpt-convfinqa \
|
| 244 |
+
# --max_length 2048 \
|
| 245 |
+
# --batch_size 1 \
|
| 246 |
+
# --learning_rate 1e-4 \
|
| 247 |
+
# --num_epochs 4
|
| 248 |
+
|
| 249 |
+
# deepspeed train_lora.py \
|
| 250 |
+
# --run_name convfinqa-chatglm2-linear \
|
| 251 |
+
# --base_model chatglm2 \
|
| 252 |
+
# --dataset data/fingpt-convfinqa \
|
| 253 |
+
# --max_length 2048 \
|
| 254 |
+
# --batch_size 1 \
|
| 255 |
+
# --learning_rate 1e-4 \
|
| 256 |
+
# --num_epochs 4
|
| 257 |
+
|
| 258 |
+
# deepspeed train_lora.py \
|
| 259 |
+
# --run_name convfinqa-falcon-linear \
|
| 260 |
+
# --base_model falcon \
|
| 261 |
+
# --dataset data/fingpt-convfinqa \
|
| 262 |
+
# --max_length 2048 \
|
| 263 |
+
# --batch_size 1 \
|
| 264 |
+
# --learning_rate 1e-4 \
|
| 265 |
+
# --num_epochs 4
|
| 266 |
+
|
| 267 |
+
# deepspeed train_lora.py \
|
| 268 |
+
# --run_name convfinqa-qwen-linear \
|
| 269 |
+
# --base_model qwen \
|
| 270 |
+
# --dataset data/fingpt-convfinqa \
|
| 271 |
+
# --max_length 2048 \
|
| 272 |
+
# --batch_size 1 \
|
| 273 |
+
# --learning_rate 1e-4 \
|
| 274 |
+
# --num_epochs 4
|
| 275 |
+
|
| 276 |
+
# deepspeed train_lora.py \
|
| 277 |
+
# --run_name convfinqa-mpt-linear \
|
| 278 |
+
# --base_model mpt \
|
| 279 |
+
# --dataset data/fingpt-convfinqa \
|
| 280 |
+
# --max_length 2048 \
|
| 281 |
+
# --batch_size 1 \
|
| 282 |
+
# --learning_rate 1e-4 \
|
| 283 |
+
# --num_epochs 4
|
| 284 |
+
|
| 285 |
+
# deepspeed train_lora.py \
|
| 286 |
+
# --run_name convfinqa-bloom-linear \
|
| 287 |
+
# --base_model bloom \
|
| 288 |
+
# --dataset data/fingpt-convfinqa \
|
| 289 |
+
# --max_length 2048 \
|
| 290 |
+
# --batch_size 1 \
|
| 291 |
+
# --learning_rate 1e-4 \
|
| 292 |
+
# --num_epochs 4
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
#---- NER ----
|
| 296 |
+
|
| 297 |
+
# deepspeed train_lora.py \
|
| 298 |
+
# --run_name ner-llama2-linear \
|
| 299 |
+
# --base_model llama2 \
|
| 300 |
+
# --dataset data/fingpt-ner \
|
| 301 |
+
# --ds_config config_hf.json \
|
| 302 |
+
# --max_length 512 \
|
| 303 |
+
# --batch_size 4 \
|
| 304 |
+
# --learning_rate 1e-4 \
|
| 305 |
+
# --num_epochs 100 \
|
| 306 |
+
# --log_interval 10
|
| 307 |
+
|
| 308 |
+
# deepspeed train_lora.py \
|
| 309 |
+
# --run_name ner-chatglm2-linear \
|
| 310 |
+
# --base_model chatglm2 \
|
| 311 |
+
# --dataset data/fingpt-ner \
|
| 312 |
+
# --max_length 512 \
|
| 313 |
+
# --batch_size 4 \
|
| 314 |
+
# --learning_rate 1e-4 \
|
| 315 |
+
# --num_epochs 100 \
|
| 316 |
+
# --log_interval 10
|
| 317 |
+
|
| 318 |
+
# deepspeed train_lora.py \
|
| 319 |
+
# --run_name ner-falcon-linear \
|
| 320 |
+
# --base_model falcon \
|
| 321 |
+
# --dataset data/fingpt-ner \
|
| 322 |
+
# --max_length 512 \
|
| 323 |
+
# --batch_size 4 \
|
| 324 |
+
# --learning_rate 1e-4 \
|
| 325 |
+
# --num_epochs 100 \
|
| 326 |
+
# --log_interval 10
|
| 327 |
+
|
| 328 |
+
# deepspeed train_lora.py \
|
| 329 |
+
# --run_name ner-qwen-linear \
|
| 330 |
+
# --base_model qwen \
|
| 331 |
+
# --dataset data/fingpt-ner \
|
| 332 |
+
# --max_length 512 \
|
| 333 |
+
# --batch_size 4 \
|
| 334 |
+
# --learning_rate 1e-4 \
|
| 335 |
+
# --num_epochs 100 \
|
| 336 |
+
# --log_interval 10
|
| 337 |
+
|
| 338 |
+
# deepspeed train_lora.py \
|
| 339 |
+
# --run_name ner-mpt-linear \
|
| 340 |
+
# --base_model mpt \
|
| 341 |
+
# --dataset data/fingpt-ner \
|
| 342 |
+
# --max_length 512 \
|
| 343 |
+
# --batch_size 4 \
|
| 344 |
+
# --learning_rate 1e-4 \
|
| 345 |
+
# --num_epochs 100 \
|
| 346 |
+
# --log_interval 10
|
| 347 |
+
|
| 348 |
+
# deepspeed train_lora.py \
|
| 349 |
+
# --run_name ner-bloom-linear \
|
| 350 |
+
# --base_model bloom \
|
| 351 |
+
# --dataset data/fingpt-ner \
|
| 352 |
+
# --max_length 512 \
|
| 353 |
+
# --batch_size 4 \
|
| 354 |
+
# --learning_rate 1e-4 \
|
| 355 |
+
# --num_epochs 100 \
|
| 356 |
+
# --log_interval 10
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
#---- Headline (IE) ----
|
| 360 |
+
|
| 361 |
+
# deepspeed train_lora.py \
|
| 362 |
+
# --run_name headline-internlm-linear \
|
| 363 |
+
# --base_model internlm \
|
| 364 |
+
# --dataset data/fingpt-headline \
|
| 365 |
+
# --ds_config config_hf.json \
|
| 366 |
+
# --max_length 512 \
|
| 367 |
+
# --batch_size 4 \
|
| 368 |
+
# --learning_rate 1e-4 \
|
| 369 |
+
# --num_epochs 8
|
| 370 |
+
|
| 371 |
+
# deepspeed train_lora.py \
|
| 372 |
+
# --run_name headline-llama2-linear \
|
| 373 |
+
# --base_model llama2 \
|
| 374 |
+
# --dataset data/fingpt-headline \
|
| 375 |
+
# --max_length 512 \
|
| 376 |
+
# --batch_size 4 \
|
| 377 |
+
# --learning_rate 1e-4 \
|
| 378 |
+
# --num_epochs 8
|
| 379 |
+
|
| 380 |
+
# deepspeed train_lora.py \
|
| 381 |
+
# --run_name headline-chatglm2-linear \
|
| 382 |
+
# --base_model chatglm2 \
|
| 383 |
+
# --dataset data/fingpt-headline \
|
| 384 |
+
# --max_length 512 \
|
| 385 |
+
# --batch_size 4 \
|
| 386 |
+
# --learning_rate 1e-4 \
|
| 387 |
+
# --num_epochs 8
|
| 388 |
+
|
| 389 |
+
# deepspeed train_lora.py \
|
| 390 |
+
# --run_name headline-falcon-linear \
|
| 391 |
+
# --base_model falcon \
|
| 392 |
+
# --dataset data/fingpt-headline \
|
| 393 |
+
# --max_length 512 \
|
| 394 |
+
# --batch_size 4 \
|
| 395 |
+
# --learning_rate 1e-4 \
|
| 396 |
+
# --num_epochs 8
|
| 397 |
+
|
| 398 |
+
# deepspeed train_lora.py \
|
| 399 |
+
# --run_name headline-qwen-linear \
|
| 400 |
+
# --base_model qwen \
|
| 401 |
+
# --dataset data/fingpt-headline \
|
| 402 |
+
# --max_length 512 \
|
| 403 |
+
# --batch_size 4 \
|
| 404 |
+
# --learning_rate 1e-4 \
|
| 405 |
+
# --num_epochs 8
|
| 406 |
+
|
| 407 |
+
# deepspeed train_lora.py \
|
| 408 |
+
# --run_name headline-mpt-linear \
|
| 409 |
+
# --base_model mpt \
|
| 410 |
+
# --dataset data/fingpt-headline \
|
| 411 |
+
# --max_length 512 \
|
| 412 |
+
# --batch_size 4 \
|
| 413 |
+
# --learning_rate 1e-4 \
|
| 414 |
+
# --num_epochs 8
|
| 415 |
+
|
| 416 |
+
# deepspeed train_lora.py \
|
| 417 |
+
# --run_name headline-bloom-linear \
|
| 418 |
+
# --base_model bloom \
|
| 419 |
+
# --dataset data/fingpt-headline \
|
| 420 |
+
# --max_length 512 \
|
| 421 |
+
# --batch_size 4 \
|
| 422 |
+
# --learning_rate 1e-4 \
|
| 423 |
+
# --num_epochs 8
|
| 424 |
+
|
| 425 |
+
#---- Sentiment Analysis ----
|
| 426 |
+
|
| 427 |
+
# deepspeed train_lora.py \
|
| 428 |
+
# --run_name sentiment-internlm-linear \
|
| 429 |
+
# --base_model internlm \
|
| 430 |
+
# --dataset data/fingpt-sentiment-train \
|
| 431 |
+
# --max_length 512 \
|
| 432 |
+
# --batch_size 4 \
|
| 433 |
+
# --learning_rate 1e-4 \
|
| 434 |
+
# --num_epochs 8
|
| 435 |
+
|
| 436 |
+
# deepspeed train_lora.py \
|
| 437 |
+
# --run_name sentiment-llama2-linear \
|
| 438 |
+
# --base_model llama2 \
|
| 439 |
+
# --dataset data/fingpt-sentiment-train \
|
| 440 |
+
# --ds_config config_hf.json \
|
| 441 |
+
# --max_length 512 \
|
| 442 |
+
# --batch_size 4 \
|
| 443 |
+
# --learning_rate 1e-4 \
|
| 444 |
+
# --num_epochs 8
|
| 445 |
+
|
| 446 |
+
# deepspeed train_lora.py \
|
| 447 |
+
# --run_name sentiment-chatglm2-linear \
|
| 448 |
+
# --base_model chatglm2 \
|
| 449 |
+
# --dataset data/fingpt-sentiment-train \
|
| 450 |
+
# --max_length 512 \
|
| 451 |
+
# --batch_size 4 \
|
| 452 |
+
# --learning_rate 1e-4 \
|
| 453 |
+
# --num_epochs 8
|
| 454 |
+
|
| 455 |
+
# deepspeed train_lora.py \
|
| 456 |
+
# --run_name sentiment-falcon-linear \
|
| 457 |
+
# --base_model falcon \
|
| 458 |
+
# --dataset data/fingpt-sentiment-train \
|
| 459 |
+
# --max_length 512 \
|
| 460 |
+
# --batch_size 4 \
|
| 461 |
+
# --learning_rate 1e-4 \
|
| 462 |
+
# --num_epochs 8
|
| 463 |
+
|
| 464 |
+
# deepspeed train_lora.py \
|
| 465 |
+
# --run_name sentiment-qwen-linear \
|
| 466 |
+
# --base_model qwen \
|
| 467 |
+
# --dataset data/fingpt-sentiment-train \
|
| 468 |
+
# --max_length 512 \
|
| 469 |
+
# --batch_size 4 \
|
| 470 |
+
# --learning_rate 1e-4 \
|
| 471 |
+
# --num_epochs 8
|
| 472 |
+
|
| 473 |
+
# deepspeed train_lora.py \
|
| 474 |
+
# --run_name sentiment-mpt-linear \
|
| 475 |
+
# --base_model mpt \
|
| 476 |
+
# --dataset data/fingpt-sentiment-train \
|
| 477 |
+
# --max_length 512 \
|
| 478 |
+
# --batch_size 4 \
|
| 479 |
+
# --learning_rate 1e-4 \
|
| 480 |
+
# --num_epochs 8
|
| 481 |
+
|
| 482 |
+
# deepspeed train_lora.py \
|
| 483 |
+
# --run_name sentiment-bloom-linear \
|
| 484 |
+
# --base_model bloom \
|
| 485 |
+
# --dataset data/fingpt-sentiment-train \
|
| 486 |
+
# --max_length 512 \
|
| 487 |
+
# --batch_size 4 \
|
| 488 |
+
# --learning_rate 1e-4 \
|
| 489 |
+
# --num_epochs 8
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
#---- Relation Extraction ----
|
| 493 |
+
|
| 494 |
+
# deepspeed train_lora.py \
|
| 495 |
+
# --run_name finred-llama2-linear \
|
| 496 |
+
# --base_model llama2 \
|
| 497 |
+
# --dataset data/fingpt-finred-re \
|
| 498 |
+
# --ds_config config_hf.json \
|
| 499 |
+
# --max_length 512 \
|
| 500 |
+
# --batch_size 4 \
|
| 501 |
+
# --learning_rate 1e-4 \
|
| 502 |
+
# --num_epochs 8
|
| 503 |
+
|
| 504 |
+
# deepspeed train_lora.py \
|
| 505 |
+
# --run_name finred-chatglm2-linear \
|
| 506 |
+
# --base_model chatglm2 \
|
| 507 |
+
# --dataset data/fingpt-finred-re \
|
| 508 |
+
# --max_length 512 \
|
| 509 |
+
# --batch_size 4 \
|
| 510 |
+
# --learning_rate 1e-4 \
|
| 511 |
+
# --num_epochs 8
|
| 512 |
+
|
| 513 |
+
# deepspeed train_lora.py \
|
| 514 |
+
# --run_name finred-falcon-linear \
|
| 515 |
+
# --base_model falcon \
|
| 516 |
+
# --dataset data/fingpt-finred-re \
|
| 517 |
+
# --max_length 512 \
|
| 518 |
+
# --batch_size 4 \
|
| 519 |
+
# --learning_rate 1e-4 \
|
| 520 |
+
# --num_epochs 8
|
| 521 |
+
|
| 522 |
+
# deepspeed train_lora.py \
|
| 523 |
+
# --run_name finred-qwen-linear \
|
| 524 |
+
# --base_model qwen \
|
| 525 |
+
# --dataset data/fingpt-finred-re \
|
| 526 |
+
# --max_length 512 \
|
| 527 |
+
# --batch_size 4 \
|
| 528 |
+
# --learning_rate 1e-4 \
|
| 529 |
+
# --num_epochs 8
|
| 530 |
+
|
| 531 |
+
# deepspeed train_lora.py \
|
| 532 |
+
# --run_name finred-mpt-linear \
|
| 533 |
+
# --base_model mpt \
|
| 534 |
+
# --dataset data/fingpt-finred-re \
|
| 535 |
+
# --max_length 512 \
|
| 536 |
+
# --batch_size 4 \
|
| 537 |
+
# --learning_rate 1e-4 \
|
| 538 |
+
# --num_epochs 8
|
| 539 |
+
|
| 540 |
+
# deepspeed train_lora.py \
|
| 541 |
+
# --run_name finred-bloom-linear \
|
| 542 |
+
# --base_model bloom \
|
| 543 |
+
# --dataset data/fingpt-finred-re \
|
| 544 |
+
# --max_length 512 \
|
| 545 |
+
# --batch_size 4 \
|
| 546 |
+
# --learning_rate 1e-4 \
|
| 547 |
+
# --num_epochs 8
|
fingpt/FinGPT_Benchmark/train_lora.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import argparse
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from functools import partial
|
| 6 |
+
import datasets
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 9 |
+
import wandb
|
| 10 |
+
from transformers import (
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
AutoModel,
|
| 13 |
+
AutoModelForCausalLM,
|
| 14 |
+
TrainingArguments,
|
| 15 |
+
Trainer,
|
| 16 |
+
DataCollatorForSeq2Seq
|
| 17 |
+
)
|
| 18 |
+
from transformers.trainer import TRAINING_ARGS_NAME
|
| 19 |
+
from transformers.integrations import TensorBoardCallback
|
| 20 |
+
# Importing LoRA specific modules
|
| 21 |
+
from peft import (
|
| 22 |
+
TaskType,
|
| 23 |
+
LoraConfig,
|
| 24 |
+
get_peft_model,
|
| 25 |
+
get_peft_model_state_dict,
|
| 26 |
+
prepare_model_for_int8_training,
|
| 27 |
+
set_peft_model_state_dict
|
| 28 |
+
)
|
| 29 |
+
from utils import *
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Weights & Biases configuration.
# SECURITY: never commit a real API key to source control — the hard-coded
# key that used to live here was a leaked credential and has been removed.
# Supply your key via the environment before launching, e.g.:
#   export WANDB_API_KEY=<your key>
# setdefault keeps any project name the caller already exported.
os.environ.setdefault('WANDB_PROJECT', 'fingpt-benchmark')
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(args):
    """
    Fine-tune a causal language model with LoRA on FinGPT instruction data.

    Loads the base model and tokenizer, prepares train/test datasets,
    attaches a LoRA adapter, and runs the Hugging Face `Trainer`.
    The adapter weights are saved to `finetuned_models/<run_name>_<time>`.

    :param args: Parsed command-line arguments (see the argparse setup in
                 the ``__main__`` block).
    """
    # Resolve the short model key to a HF hub id or local path.
    model_name = parse_model_name(args.base_model, args.from_remote)

    # Load the pre-trained causal language model.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # load_in_8bit=True,
        # device_map="auto",
        trust_remote_code=True
    )
    # Only rank 0 prints the architecture to avoid duplicate output in
    # distributed training.
    if args.local_rank == 0:
        print(model)

    # Tokenizer associated with the pre-trained model.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Model-specific tokenization settings.
    if args.base_model != 'mpt':
        tokenizer.padding_side = "left"
    if args.base_model == 'qwen':
        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')
    # Ensure a dedicated padding token distinct from EOS; resize embeddings
    # when a new token is added.
    if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model.resize_token_embeddings(len(tokenizer))

    # Load datasets. `load_dataset` (from utils) returns a list of
    # DatasetDicts, each guaranteed to contain 'train' and 'test' splits.
    dataset_list = load_dataset(args.dataset, args.from_remote)
    dataset_train = datasets.concatenate_datasets(
        [d['train'] for d in dataset_list]).shuffle(seed=42)

    # BUG FIX: previously `dataset_test` was only defined when
    # --test_dataset was supplied, raising NameError otherwise. Fall back
    # to the held-out 'test' splits of the training datasets.
    if args.test_dataset:
        test_list = load_dataset(args.test_dataset, args.from_remote)
    else:
        test_list = dataset_list
    dataset_test = datasets.concatenate_datasets([d['test'] for d in test_list])

    dataset = datasets.DatasetDict({'train': dataset_train, 'test': dataset_test})
    # Display the first training sample for a sanity check.
    print(dataset['train'][0])

    # Tokenize, drop samples exceeding max_length, and strip raw text columns.
    dataset = dataset.map(partial(tokenize, args, tokenizer))
    print('original dataset length: ', len(dataset['train']))
    dataset = dataset.filter(lambda x: not x['exceed_max_length'])
    print('filtered dataset length: ', len(dataset['train']))
    dataset = dataset.remove_columns(
        ['instruction', 'input', 'output', 'exceed_max_length'])

    print(dataset['train'][0])

    # Timestamped output directory so repeated runs don't overwrite.
    formatted_time = datetime.now().strftime('%Y%m%d%H%M')

    training_args = TrainingArguments(
        output_dir=f'finetuned_models/{args.run_name}_{formatted_time}',  # save location
        logging_steps=args.log_interval,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        # HF requires an int here; the CLI may have parsed a float.
        gradient_accumulation_steps=int(args.gradient_steps),
        dataloader_num_workers=args.num_workers,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        save_steps=args.eval_steps,
        eval_steps=args.eval_steps,
        fp16=True,
        # fp16_full_eval=True,
        deepspeed=args.ds_config,
        evaluation_strategy=args.evaluation_strategy,
        load_best_model_at_end=args.load_best_model,
        remove_unused_columns=False,
        report_to='wandb',
        run_name=args.run_name
    )

    if args.base_model != 'mpt':
        # Gradient checkpointing trades compute for memory; inputs must
        # require grads for it to work with a (partially) frozen model.
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
    model.is_parallelizable = True
    model.model_parallel = True
    # KV-cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    # model = prepare_model_for_int8_training(model)

    # LoRA adapter configuration; target modules depend on the base model.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=lora_module_dict[args.base_model],
        bias='none',
    )
    model = get_peft_model(model, peft_config)

    # TensorBoard logging alongside the wandb reporting above.
    writer = SummaryWriter()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, padding=True,
            return_tensors="pt"
        ),
        callbacks=[TensorBoardCallback(writer)],
    )

    # if torch.__version__ >= "2" and sys.platform != "win32":
    #     model = torch.compile(model)

    # Clear CUDA cache and start training.
    torch.cuda.empty_cache()
    trainer.train()
    writer.close()

    # Save the fine-tuned adapter weights (PEFT model).
    model.save_pretrained(training_args.output_dir)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == "__main__":
    # Command-line interface for the LoRA fine-tuning script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--run_name", default='local-test', type=str)
    parser.add_argument("--dataset", required=True, type=str)
    parser.add_argument("--test_dataset", type=str)
    parser.add_argument("--base_model", required=True, type=str, choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])
    parser.add_argument("--max_length", default=512, type=int)
    parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="The learning rate")
    parser.add_argument("--num_epochs", default=8, type=float, help="The training epochs")
    # Gradient accumulation steps must be an integer for TrainingArguments.
    parser.add_argument("--gradient_steps", default=8, type=int, help="The gradient accumulation steps")
    parser.add_argument("--num_workers", default=8, type=int, help="dataloader workers")
    parser.add_argument("--log_interval", default=20, type=int)
    parser.add_argument("--warmup_ratio", default=0.05, type=float)
    parser.add_argument("--ds_config", default='./config_new.json', type=str)
    parser.add_argument("--scheduler", default='linear', type=str)
    parser.add_argument("--instruct_template", default='default')
    parser.add_argument("--evaluation_strategy", default='steps', type=str)
    # BUG FIX: argparse `type=bool` converts ANY non-empty string
    # (including "False") to True, and the old default was the *string*
    # 'False' (truthy). Parse booleans explicitly instead.
    parser.add_argument("--load_best_model", default=False,
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'))
    parser.add_argument("--eval_steps", default=0.1, type=float)
    parser.add_argument("--from_remote", default=False,
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'))
    args = parser.parse_args()

    # Login to Weights and Biases (uses WANDB_API_KEY from the environment).
    wandb.login()

    # Run the main function.
    main(args)
|
fingpt/FinGPT_Benchmark/utils.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import datasets
|
| 3 |
+
|
| 4 |
+
# A dictionary to store various prompt templates.
# Keys are template names (selectable via --instruct_template); values are
# format strings with `instruction` and `input` placeholders, consumed by
# `get_prompt` below.
template_dict = {
    'default': 'Instruction: {instruction}\nInput: {input}\nAnswer: '
}

# A dictionary to store the LoRA module mapping for different models.
# Maps each supported base-model key to the attention-projection module
# names that LoRA adapters attach to (passed as `target_modules` to
# `peft.LoraConfig` in train_lora.py).
lora_module_dict = {
    'chatglm2': ['query_key_value'],
    'falcon': ['query_key_value'],
    'bloom': ['query_key_value'],
    'internlm': ['q_proj', 'k_proj', 'v_proj'],
    'llama2': ['q_proj', 'k_proj', 'v_proj'],
    'llama2-13b': ['q_proj', 'k_proj', 'v_proj'],
    'llama2-13b-nr': ['q_proj', 'k_proj', 'v_proj'],
    'qwen': ["c_attn"],
    'mpt': ['Wqkv'],
    'baichuan': ['q_proj', 'k_proj', 'v_proj'],
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_prompt(template, instruction, input_text):
    """
    Build a prompt string from a named template.

    Args:
        template (str): Key selecting a format string from `template_dict`.
        instruction (str): Instruction text; when falsy, `input_text` is
            returned unchanged.
        input_text (str): Input text substituted into the template.

    Returns:
        str: The formatted prompt.

    Raises:
        KeyError: If `template` is not a known template name.
    """
    # No instruction means the raw input is the whole prompt.
    if not instruction:
        return input_text

    if template in template_dict:
        return template_dict[template].format(
            instruction=instruction, input=input_text)

    available = ', '.join(template_dict.keys())
    raise KeyError(f"Template '{template}' not found. Available templates: {available}")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_mapping(args, feature):
    """
    Build the inference-time prompt for one dataset example.

    Args:
        args (Namespace): Configuration providing `instruct_template`.
        feature (dict): Example holding 'instruction' and 'input' fields.

    Returns:
        dict: A dictionary with a single 'prompt' key.

    Raises:
        ValueError: If 'instruction' or 'input' is missing from `feature`.
    """
    # Reject examples missing either required field.
    if {'instruction', 'input'} - feature.keys():
        raise ValueError("Both 'instruction' and 'input' need to be provided in the feature dictionary.")

    return {
        "prompt": get_prompt(
            args.instruct_template,
            feature['instruction'],
            feature['input'],
        ),
    }
|
| 76 |
+
|
| 77 |
+
def tokenize(args, tokenizer, feature):
    """
    Tokenize a single example for supervised fine-tuning.

    The prompt and the target output are tokenized separately, then
    concatenated. The label sequence masks the prompt portion with the
    tokenizer's pad token id so only the answer span carries targets.

    Args:
        args (Namespace): Provides `instruct_template` and `max_length`.
        tokenizer (Tokenizer): Hugging Face tokenizer for the base model.
        feature (dict): Example with 'instruction', 'input', 'output'.

    Returns:
        dict: 'input_ids', 'labels', and a boolean 'exceed_max_length'.
    """
    prompt_text = get_prompt(
        args.instruct_template,
        feature['instruction'],
        feature['input'],
    )

    # Tokenize the prompt (with special tokens, truncated to max_length).
    prompt_tokens = tokenizer(
        prompt_text,
        padding=False,
        max_length=args.max_length,
        truncation=True,
    )['input_ids']

    # Tokenize the answer without special tokens so the two halves
    # concatenate cleanly.
    answer_tokens = tokenizer(
        feature['output'].strip(),
        padding=False,
        max_length=args.max_length,
        truncation=True,
        add_special_tokens=False,
    )['input_ids']

    input_ids = prompt_tokens + answer_tokens
    exceed_max_length = len(input_ids) >= args.max_length

    # Append EOS when there is room and it is not already the last token.
    if not exceed_max_length and input_ids[-1] != tokenizer.eos_token_id:
        input_ids = input_ids + [tokenizer.eos_token_id]

    # Mask the prompt with pad ids; supervise only the answer (+ EOS).
    # NOTE(review): many pipelines mask with -100 so the loss ignores the
    # prompt entirely; pad_token_id is kept here to preserve the existing
    # behavior -- confirm downstream loss masking before changing.
    label_ids = [tokenizer.pad_token_id] * len(prompt_tokens) \
        + input_ids[len(prompt_tokens):]

    return {
        "input_ids": input_ids,
        "labels": label_ids,
        "exceed_max_length": exceed_max_length,
    }
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def parse_model_name(name, from_remote=False):
    """
    Resolve a short model key to a Hugging Face hub id or a local path.

    Args:
        name (str): Short model key, e.g. 'llama2' or 'qwen'.
        from_remote (bool): When True return the hub id; otherwise return
            the local `base_models/` path.

    Returns:
        str: The model hub identifier or local path.

    Raises:
        ValueError: If `name` is not a recognised model key.
    """
    # key -> (remote hub id, local path)
    model_paths = {
        'chatglm2': ('THUDM/chatglm2-6b', 'base_models/chatglm2-6b'),
        'llama2': ('meta-llama/Llama-2-7b-hf', 'base_models/Llama-2-7b-hf'),
        'llama2-13b': ('meta-llama/Llama-2-13b-hf', 'base_models/Llama-2-13b-hf'),
        'llama2-13b-nr': ('NousResearch/Llama-2-13b-hf', 'base_models/Llama-2-13b-hf'),
        'falcon': ('tiiuae/falcon-7b', 'base_models/falcon-7b'),
        'internlm': ('internlm/internlm-7b', 'base_models/internlm-7b'),
        'qwen': ('Qwen/Qwen-7B', 'base_models/Qwen-7B'),
        'baichuan': ('baichuan-inc/Baichuan2-7B-Base', 'base_models/Baichuan2-7B-Base'),
        'mpt': ('cekal/mpt-7b-peft-compatible', 'base_models/mpt-7b-peft-compatible'),
        'bloom': ('bigscience/bloom-7b1', 'base_models/bloom-7b1'),
    }

    if name not in model_paths:
        valid_model_names = ', '.join(model_paths.keys())
        raise ValueError(f"Undefined base model '{name}'. Valid model names are: {valid_model_names}")

    remote_id, local_path = model_paths[name]
    return remote_id if from_remote else local_path
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def load_dataset(names, from_remote=False):
    """
    Load one or more FinGPT datasets, optionally replicated.

    Args:
        names (str): Comma-separated dataset names. A name may carry a
            replication suffix, e.g. 'fiqa*3' loads fiqa three times.
        from_remote (bool): When True load 'FinGPT/fingpt-<name>' from the
            Hugging Face hub; otherwise load 'data/fingpt-<name>' from disk.

    Returns:
        list: Loaded datasets; each is a DatasetDict with 'train' and
        'test' splits (a 'test' split is created from 'train' via an
        80/20 split when absent). Replicated names appear multiple times.

    Raises:
        ValueError: On a non-positive replication factor, or a dataset
            lacking both 'train' and 'test' splits.
        FileNotFoundError: When a local dataset path does not exist.
        RuntimeError: When loading the dataset itself fails.
    """
    dataset_list = []

    for name in names.split(','):
        replication_factor = 1
        dataset_name = name

        # Optional '<name>*<k>' replication suffix.
        if '*' in name:
            # BUG FIX: use rsplit with maxsplit=1 so a stray '*' earlier in
            # the name cannot break the two-value unpacking.
            dataset_name, factor_str = name.rsplit('*', 1)
            replication_factor = int(factor_str)
            if replication_factor < 1:
                raise ValueError("Replication factor must be a positive integer.")

        # Construct the dataset path or hub name based on the source.
        dataset_path_or_name = ('FinGPT/fingpt-' if from_remote else 'data/fingpt-') + dataset_name
        if not from_remote and not os.path.exists(dataset_path_or_name):
            raise FileNotFoundError(f"The dataset path {dataset_path_or_name} does not exist.")

        # Load the dataset, preserving the original cause on failure.
        try:
            if from_remote:
                tmp_dataset = datasets.load_dataset(dataset_path_or_name)
            else:
                tmp_dataset = datasets.load_from_disk(dataset_path_or_name)
        except Exception as e:
            # BUG FIX: chain with `from e` so the original traceback is kept.
            raise RuntimeError(f"Failed to load the dataset: {str(e)}") from e

        # Guarantee a 'test' split, deriving one from 'train' if needed.
        if 'test' not in tmp_dataset:
            if 'train' not in tmp_dataset:
                raise ValueError("The dataset must contain a 'train' or 'test' split.")
            tmp_dataset = tmp_dataset['train'].train_test_split(
                test_size=0.2, shuffle=True, seed=42)

        # Append the (possibly replicated) dataset.
        dataset_list.extend([tmp_dataset] * replication_factor)

    return dataset_list
|
fingpt/FinGPT_FinancialReportAnalysis/README.md
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Financial Report Analysis Project
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
This project provides tools for analyzing financial reports, specifically annual reports (10-K), using advanced language models such as GPT-4 or other locally deployed Large Language Models (LLM). It's designed to help users generate comprehensive analysis reports in PDF format, offering insights into a company's financial health and performance over the fiscal year.
|
| 6 |
+
|
| 7 |
+
## Features
|
| 8 |
+
|
| 9 |
+
- **PDF Report Generation**: Automatically generate detailed analysis reports in PDF format for annual financial statements.
|
| 10 |
+
- **GPT-4 and LLM Support**: Utilize the power of GPT-4 or any locally deployed LLM for deep and insightful analysis.
|
| 11 |
+
- **RAG Support**: The ability to utilize the power of RAG for question-answering and summarization tasks.
|
| 12 |
+
- **Customizable Analysis**: Users can modify the analysis scope by choosing different company symbols and models.
|
| 13 |
+
- **Easy to Use**: Designed with simplicity in mind, simply run all cells in the provided notebook to get your report.
|
| 14 |
+
|
| 15 |
+
## Requirements
|
| 16 |
+
|
| 17 |
+
Before starting, ensure you have the following installed:
|
| 18 |
+
- Python 3.11 or later
|
| 19 |
+
- Jupyter Notebook
|
| 20 |
+
- Necessary Python packages (pandas, matplotlib, openai, etc.)
|
| 21 |
+
|
| 22 |
+
Obtain a free sec-api key (used to download the 10-K filings) from https://sec-api.io/profile.
|
| 23 |
+
|
| 24 |
+
(Optional) Obtain a paid Financial Modeling Prep (FMP) API key (used for target-price data) from https://site.financialmodelingprep.com/developer/docs/dashboard.
|
| 25 |
+
|
| 26 |
+
## Getting Started
|
| 27 |
+
|
| 28 |
+
To begin analyzing financial reports:
|
| 29 |
+
|
| 30 |
+
0. **(optional) Prepare the local LLM**:
|
| 31 |
+
If you want to run the analysis with the locally deployed models, please download Ollama and have it running: https://ollama.com/download.
|
| 32 |
+
Also, download the model you want to use in the list of available models: https://ollama.com/library with command:
|
| 33 |
+
```bash
|
| 34 |
+
ollama run <model_name>
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
1. **Open the Notebook**:
|
| 38 |
+
Launch Jupyter Notebook and open the `reportanalysis.ipynb` file:
|
| 39 |
+
```
|
| 40 |
+
jupyter notebook reportanalysis.ipynb
|
| 41 |
+
```
|
| 42 |
+
All the necessary libraries and dependencies are already imported in the notebook.
|
| 43 |
+
|
| 44 |
+
2. **Configure the Notebook**:
|
| 45 |
+
Modify the `company symbol` and `models` variables within the notebook to suit the analysis you wish to perform.
|
| 46 |
+
|
| 47 |
+
3. **Run the Analysis**:
|
| 48 |
+
Execute all cells in the notebook to generate your financial report analysis in PDF format.
|
| 49 |
+
|
| 50 |
+
## Contributing
|
| 51 |
+
|
| 52 |
+
We welcome contributions and suggestions! Please open an issue or submit a pull request with your improvements.
|
fingpt/FinGPT_FinancialReportAnalysis/reportanalysis.ipynb
ADDED
|
@@ -0,0 +1,1085 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 253,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stdout",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"Requirement already satisfied: reportlab in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (4.1.0)\n",
|
| 13 |
+
"Requirement already satisfied: yfinance in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.2.37)\n",
|
| 14 |
+
"Requirement already satisfied: matplotlib in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (3.8.4)\n",
|
| 15 |
+
"Requirement already satisfied: scrapy in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (2.11.1)\n",
|
| 16 |
+
"Requirement already satisfied: sec_api in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (1.0.17)\n",
|
| 17 |
+
"Requirement already satisfied: langchain in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.1.14)\n",
|
| 18 |
+
"Requirement already satisfied: umap-learn in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.5.6)\n",
|
| 19 |
+
"Requirement already satisfied: scikit-learn in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (1.4.1.post1)\n",
|
| 20 |
+
"Requirement already satisfied: langchain_community in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.0.31)\n",
|
| 21 |
+
"Requirement already satisfied: tiktoken in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.6.0)\n",
|
| 22 |
+
"Requirement already satisfied: langchain-openai in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.1.1)\n",
|
| 23 |
+
"Requirement already satisfied: langchainhub in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.1.15)\n",
|
| 24 |
+
"Requirement already satisfied: chromadb in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.4.24)\n",
|
| 25 |
+
"Requirement already satisfied: langchain-anthropic in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (0.1.6)\n",
|
| 26 |
+
"Requirement already satisfied: sentence-transformers in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (2.6.1)\n",
|
| 27 |
+
"Collecting ollama\n",
|
| 28 |
+
" Downloading ollama-0.1.8-py3-none-any.whl.metadata (3.8 kB)\n",
|
| 29 |
+
"Requirement already satisfied: pillow>=9.0.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from reportlab) (10.3.0)\n",
|
| 30 |
+
"Requirement already satisfied: chardet in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from reportlab) (5.2.0)\n",
|
| 31 |
+
"Requirement already satisfied: pandas>=1.3.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (2.2.1)\n",
|
| 32 |
+
"Requirement already satisfied: numpy>=1.16.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (1.26.4)\n",
|
| 33 |
+
"Requirement already satisfied: requests>=2.31 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (2.31.0)\n",
|
| 34 |
+
"Requirement already satisfied: multitasking>=0.0.7 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (0.0.11)\n",
|
| 35 |
+
"Requirement already satisfied: lxml>=4.9.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (5.2.1)\n",
|
| 36 |
+
"Requirement already satisfied: appdirs>=1.4.4 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (1.4.4)\n",
|
| 37 |
+
"Requirement already satisfied: pytz>=2022.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (2024.1)\n",
|
| 38 |
+
"Requirement already satisfied: frozendict>=2.3.4 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (2.4.1)\n",
|
| 39 |
+
"Requirement already satisfied: peewee>=3.16.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (3.17.1)\n",
|
| 40 |
+
"Requirement already satisfied: beautifulsoup4>=4.11.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (4.12.3)\n",
|
| 41 |
+
"Requirement already satisfied: html5lib>=1.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from yfinance) (1.1)\n",
|
| 42 |
+
"Requirement already satisfied: contourpy>=1.0.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from matplotlib) (1.2.1)\n",
|
| 43 |
+
"Requirement already satisfied: cycler>=0.10 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from matplotlib) (0.12.1)\n",
|
| 44 |
+
"Requirement already satisfied: fonttools>=4.22.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from matplotlib) (4.51.0)\n",
|
| 45 |
+
"Requirement already satisfied: kiwisolver>=1.3.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from matplotlib) (1.4.5)\n",
|
| 46 |
+
"Requirement already satisfied: packaging>=20.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from matplotlib) (23.2)\n",
|
| 47 |
+
"Requirement already satisfied: pyparsing>=2.3.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from matplotlib) (3.1.2)\n",
|
| 48 |
+
"Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from matplotlib) (2.9.0)\n",
|
| 49 |
+
"Requirement already satisfied: Twisted>=18.9.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (24.3.0)\n",
|
| 50 |
+
"Requirement already satisfied: cryptography>=36.0.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (42.0.5)\n",
|
| 51 |
+
"Requirement already satisfied: cssselect>=0.9.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (1.2.0)\n",
|
| 52 |
+
"Requirement already satisfied: itemloaders>=1.0.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (1.1.0)\n",
|
| 53 |
+
"Requirement already satisfied: parsel>=1.5.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (1.9.0)\n",
|
| 54 |
+
"Requirement already satisfied: pyOpenSSL>=21.0.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (24.1.0)\n",
|
| 55 |
+
"Requirement already satisfied: queuelib>=1.4.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (1.6.2)\n",
|
| 56 |
+
"Requirement already satisfied: service-identity>=18.1.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (24.1.0)\n",
|
| 57 |
+
"Requirement already satisfied: w3lib>=1.17.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (2.1.2)\n",
|
| 58 |
+
"Requirement already satisfied: zope.interface>=5.1.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (6.2)\n",
|
| 59 |
+
"Requirement already satisfied: protego>=0.1.15 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (0.3.1)\n",
|
| 60 |
+
"Requirement already satisfied: itemadapter>=0.1.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (0.8.0)\n",
|
| 61 |
+
"Requirement already satisfied: setuptools in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (68.2.2)\n",
|
| 62 |
+
"Requirement already satisfied: tldextract in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (5.1.2)\n",
|
| 63 |
+
"Requirement already satisfied: PyDispatcher>=2.0.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scrapy) (2.0.7)\n",
|
| 64 |
+
"Requirement already satisfied: PyYAML>=5.3 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (6.0.1)\n",
|
| 65 |
+
"Requirement already satisfied: SQLAlchemy<3,>=1.4 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (2.0.29)\n",
|
| 66 |
+
"Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (3.9.3)\n",
|
| 67 |
+
"Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (0.6.4)\n",
|
| 68 |
+
"Requirement already satisfied: jsonpatch<2.0,>=1.33 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (1.33)\n",
|
| 69 |
+
"Requirement already satisfied: langchain-core<0.2.0,>=0.1.37 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (0.1.40)\n",
|
| 70 |
+
"Requirement already satisfied: langchain-text-splitters<0.1,>=0.0.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (0.0.1)\n",
|
| 71 |
+
"Requirement already satisfied: langsmith<0.2.0,>=0.1.17 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (0.1.40)\n",
|
| 72 |
+
"Requirement already satisfied: pydantic<3,>=1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (2.6.4)\n",
|
| 73 |
+
"Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain) (8.2.3)\n",
|
| 74 |
+
"Requirement already satisfied: scipy>=1.3.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from umap-learn) (1.13.0)\n",
|
| 75 |
+
"Requirement already satisfied: numba>=0.51.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from umap-learn) (0.59.1)\n",
|
| 76 |
+
"Requirement already satisfied: pynndescent>=0.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from umap-learn) (0.5.12)\n",
|
| 77 |
+
"Requirement already satisfied: tqdm in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from umap-learn) (4.66.2)\n",
|
| 78 |
+
"Requirement already satisfied: joblib>=1.2.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scikit-learn) (1.3.2)\n",
|
| 79 |
+
"Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from scikit-learn) (3.4.0)\n",
|
| 80 |
+
"Requirement already satisfied: regex>=2022.1.18 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from tiktoken) (2023.12.25)\n",
|
| 81 |
+
"Requirement already satisfied: openai<2.0.0,>=1.10.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain-openai) (1.16.2)\n",
|
| 82 |
+
"Requirement already satisfied: types-requests<3.0.0.0,>=2.31.0.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchainhub) (2.31.0.20240406)\n",
|
| 83 |
+
"Requirement already satisfied: build>=1.0.3 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (1.2.1)\n",
|
| 84 |
+
"Requirement already satisfied: chroma-hnswlib==0.7.3 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (0.7.3)\n",
|
| 85 |
+
"Requirement already satisfied: fastapi>=0.95.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (0.110.1)\n",
|
| 86 |
+
"Requirement already satisfied: uvicorn>=0.18.3 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.29.0)\n",
|
| 87 |
+
"Requirement already satisfied: posthog>=2.4.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (3.5.0)\n",
|
| 88 |
+
"Requirement already satisfied: typing-extensions>=4.5.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (4.10.0)\n",
|
| 89 |
+
"Requirement already satisfied: pulsar-client>=3.1.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (3.4.0)\n",
|
| 90 |
+
"Requirement already satisfied: onnxruntime>=1.14.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (1.17.1)\n",
|
| 91 |
+
"Requirement already satisfied: opentelemetry-api>=1.2.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (1.24.0)\n",
|
| 92 |
+
"Requirement already satisfied: opentelemetry-exporter-otlp-proto-grpc>=1.2.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (1.24.0)\n",
|
| 93 |
+
"Requirement already satisfied: opentelemetry-instrumentation-fastapi>=0.41b0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (0.45b0)\n",
|
| 94 |
+
"Requirement already satisfied: opentelemetry-sdk>=1.2.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (1.24.0)\n",
|
| 95 |
+
"Requirement already satisfied: tokenizers>=0.13.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (0.15.2)\n",
|
| 96 |
+
"Requirement already satisfied: pypika>=0.48.9 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (0.48.9)\n",
|
| 97 |
+
"Requirement already satisfied: overrides>=7.3.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (7.7.0)\n",
|
| 98 |
+
"Requirement already satisfied: importlib-resources in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (6.4.0)\n",
|
| 99 |
+
"Requirement already satisfied: grpcio>=1.58.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (1.62.1)\n",
|
| 100 |
+
"Requirement already satisfied: bcrypt>=4.0.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (4.1.2)\n",
|
| 101 |
+
"Requirement already satisfied: typer>=0.9.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (0.12.1)\n",
|
| 102 |
+
"Requirement already satisfied: kubernetes>=28.1.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (29.0.0)\n",
|
| 103 |
+
"Requirement already satisfied: mmh3>=4.0.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (4.1.0)\n",
|
| 104 |
+
"Requirement already satisfied: orjson>=3.9.12 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from chromadb) (3.10.0)\n",
|
| 105 |
+
"Requirement already satisfied: anthropic<1,>=0.23.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain-anthropic) (0.23.1)\n",
|
| 106 |
+
"Requirement already satisfied: defusedxml<0.8.0,>=0.7.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from langchain-anthropic) (0.7.1)\n",
|
| 107 |
+
"Requirement already satisfied: transformers<5.0.0,>=4.32.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from sentence-transformers) (4.39.3)\n",
|
| 108 |
+
"Requirement already satisfied: torch>=1.11.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from sentence-transformers) (2.2.2)\n",
|
| 109 |
+
"Requirement already satisfied: huggingface-hub>=0.15.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from sentence-transformers) (0.22.2)\n",
|
| 110 |
+
"Requirement already satisfied: httpx<0.28.0,>=0.27.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from ollama) (0.27.0)\n",
|
| 111 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
|
| 112 |
+
"Requirement already satisfied: attrs>=17.3.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n",
|
| 113 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n",
|
| 114 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n",
|
| 115 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n",
|
| 116 |
+
"Requirement already satisfied: anyio<5,>=3.5.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (4.3.0)\n",
|
| 117 |
+
"Requirement already satisfied: distro<2,>=1.7.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.9.0)\n",
|
| 118 |
+
"Requirement already satisfied: sniffio in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.3.1)\n",
|
| 119 |
+
"Requirement already satisfied: soupsieve>1.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from beautifulsoup4>=4.11.1->yfinance) (2.5)\n",
|
| 120 |
+
"Requirement already satisfied: pyproject_hooks in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from build>=1.0.3->chromadb) (1.0.0)\n",
|
| 121 |
+
"Requirement already satisfied: colorama in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from build>=1.0.3->chromadb) (0.4.6)\n",
|
| 122 |
+
"Requirement already satisfied: cffi>=1.12 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from cryptography>=36.0.0->scrapy) (1.16.0)\n",
|
| 123 |
+
"Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.21.1)\n",
|
| 124 |
+
"Requirement already satisfied: typing-inspect<1,>=0.4.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n",
|
| 125 |
+
"Requirement already satisfied: starlette<0.38.0,>=0.37.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from fastapi>=0.95.2->chromadb) (0.37.2)\n",
|
| 126 |
+
"Requirement already satisfied: six>=1.9 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from html5lib>=1.1->yfinance) (1.16.0)\n",
|
| 127 |
+
"Requirement already satisfied: webencodings in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from html5lib>=1.1->yfinance) (0.5.1)\n",
|
| 128 |
+
"Requirement already satisfied: certifi in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from httpx<0.28.0,>=0.27.0->ollama) (2024.2.2)\n",
|
| 129 |
+
"Requirement already satisfied: httpcore==1.* in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from httpx<0.28.0,>=0.27.0->ollama) (1.0.5)\n",
|
| 130 |
+
"Requirement already satisfied: idna in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from httpx<0.28.0,>=0.27.0->ollama) (3.6)\n",
|
| 131 |
+
"Requirement already satisfied: h11<0.15,>=0.13 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->ollama) (0.14.0)\n",
|
| 132 |
+
"Requirement already satisfied: filelock in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from huggingface-hub>=0.15.1->sentence-transformers) (3.13.3)\n",
|
| 133 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from huggingface-hub>=0.15.1->sentence-transformers) (2024.3.1)\n",
|
| 134 |
+
"Requirement already satisfied: jmespath>=0.9.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from itemloaders>=1.0.1->scrapy) (1.0.1)\n",
|
| 135 |
+
"Requirement already satisfied: jsonpointer>=1.9 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from jsonpatch<2.0,>=1.33->langchain) (2.4)\n",
|
| 136 |
+
"Requirement already satisfied: google-auth>=1.0.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from kubernetes>=28.1.0->chromadb) (2.29.0)\n",
|
| 137 |
+
"Requirement already satisfied: websocket-client!=0.40.0,!=0.41.*,!=0.42.*,>=0.32.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from kubernetes>=28.1.0->chromadb) (1.7.0)\n",
|
| 138 |
+
"Requirement already satisfied: requests-oauthlib in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from kubernetes>=28.1.0->chromadb) (2.0.0)\n",
|
| 139 |
+
"Requirement already satisfied: oauthlib>=3.2.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from kubernetes>=28.1.0->chromadb) (3.2.2)\n",
|
| 140 |
+
"Requirement already satisfied: urllib3>=1.24.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from kubernetes>=28.1.0->chromadb) (2.2.1)\n",
|
| 141 |
+
"Requirement already satisfied: llvmlite<0.43,>=0.42.0dev0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from numba>=0.51.2->umap-learn) (0.42.0)\n",
|
| 142 |
+
"Requirement already satisfied: coloredlogs in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from onnxruntime>=1.14.1->chromadb) (15.0.1)\n",
|
| 143 |
+
"Requirement already satisfied: flatbuffers in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from onnxruntime>=1.14.1->chromadb) (24.3.25)\n",
|
| 144 |
+
"Requirement already satisfied: protobuf in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from onnxruntime>=1.14.1->chromadb) (4.25.3)\n",
|
| 145 |
+
"Requirement already satisfied: sympy in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from onnxruntime>=1.14.1->chromadb) (1.12)\n",
|
| 146 |
+
"Requirement already satisfied: deprecated>=1.2.6 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-api>=1.2.0->chromadb) (1.2.14)\n",
|
| 147 |
+
"Requirement already satisfied: importlib-metadata<=7.0,>=6.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-api>=1.2.0->chromadb) (7.0.0)\n",
|
| 148 |
+
"Requirement already satisfied: googleapis-common-protos~=1.52 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb) (1.63.0)\n",
|
| 149 |
+
"Requirement already satisfied: opentelemetry-exporter-otlp-proto-common==1.24.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb) (1.24.0)\n",
|
| 150 |
+
"Requirement already satisfied: opentelemetry-proto==1.24.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb) (1.24.0)\n",
|
| 151 |
+
"Requirement already satisfied: opentelemetry-instrumentation-asgi==0.45b0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
|
| 152 |
+
"Requirement already satisfied: opentelemetry-instrumentation==0.45b0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
|
| 153 |
+
"Requirement already satisfied: opentelemetry-semantic-conventions==0.45b0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
|
| 154 |
+
"Requirement already satisfied: opentelemetry-util-http==0.45b0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (0.45b0)\n",
|
| 155 |
+
"Requirement already satisfied: wrapt<2.0.0,>=1.0.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-instrumentation==0.45b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (1.16.0)\n",
|
| 156 |
+
"Requirement already satisfied: asgiref~=3.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from opentelemetry-instrumentation-asgi==0.45b0->opentelemetry-instrumentation-fastapi>=0.41b0->chromadb) (3.8.1)\n",
|
| 157 |
+
"Requirement already satisfied: tzdata>=2022.7 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from pandas>=1.3.0->yfinance) (2024.1)\n",
|
| 158 |
+
"Requirement already satisfied: monotonic>=1.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from posthog>=2.4.0->chromadb) (1.6)\n",
|
| 159 |
+
"Requirement already satisfied: backoff>=1.10.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from posthog>=2.4.0->chromadb) (2.2.1)\n",
|
| 160 |
+
"Requirement already satisfied: annotated-types>=0.4.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from pydantic<3,>=1->langchain) (0.6.0)\n",
|
| 161 |
+
"Requirement already satisfied: pydantic-core==2.16.3 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from pydantic<3,>=1->langchain) (2.16.3)\n",
|
| 162 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from requests>=2.31->yfinance) (3.3.2)\n",
|
| 163 |
+
"Requirement already satisfied: pyasn1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from service-identity>=18.1.0->scrapy) (0.6.0)\n",
|
| 164 |
+
"Requirement already satisfied: pyasn1-modules in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from service-identity>=18.1.0->scrapy) (0.4.0)\n",
|
| 165 |
+
"Requirement already satisfied: greenlet!=0.4.17 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from SQLAlchemy<3,>=1.4->langchain) (3.0.3)\n",
|
| 166 |
+
"Requirement already satisfied: networkx in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from torch>=1.11.0->sentence-transformers) (3.2.1)\n",
|
| 167 |
+
"Requirement already satisfied: jinja2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from torch>=1.11.0->sentence-transformers) (3.1.3)\n",
|
| 168 |
+
"Requirement already satisfied: safetensors>=0.4.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from transformers<5.0.0,>=4.32.0->sentence-transformers) (0.4.2)\n",
|
| 169 |
+
"Requirement already satisfied: automat>=0.8.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from Twisted>=18.9.0->scrapy) (22.10.0)\n",
|
| 170 |
+
"Requirement already satisfied: constantly>=15.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from Twisted>=18.9.0->scrapy) (23.10.4)\n",
|
| 171 |
+
"Requirement already satisfied: hyperlink>=17.1.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from Twisted>=18.9.0->scrapy) (21.0.0)\n",
|
| 172 |
+
"Requirement already satisfied: incremental>=22.10.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from Twisted>=18.9.0->scrapy) (22.10.0)\n",
|
| 173 |
+
"Requirement already satisfied: twisted-iocpsupport<2,>=1.0.2 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from Twisted>=18.9.0->scrapy) (1.0.4)\n",
|
| 174 |
+
"Requirement already satisfied: click>=8.0.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from typer>=0.9.0->chromadb) (8.1.7)\n",
|
| 175 |
+
"Requirement already satisfied: shellingham>=1.3.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from typer>=0.9.0->chromadb) (1.5.4)\n",
|
| 176 |
+
"Requirement already satisfied: rich>=10.11.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from typer>=0.9.0->chromadb) (13.7.1)\n",
|
| 177 |
+
"Requirement already satisfied: httptools>=0.5.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.6.1)\n",
|
| 178 |
+
"Requirement already satisfied: python-dotenv>=0.13 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from uvicorn[standard]>=0.18.3->chromadb) (1.0.1)\n",
|
| 179 |
+
"Requirement already satisfied: watchfiles>=0.13 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from uvicorn[standard]>=0.18.3->chromadb) (0.21.0)\n",
|
| 180 |
+
"Requirement already satisfied: websockets>=10.4 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from uvicorn[standard]>=0.18.3->chromadb) (12.0)\n",
|
| 181 |
+
"Requirement already satisfied: requests-file>=1.4 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from tldextract->scrapy) (2.0.0)\n",
|
| 182 |
+
"Requirement already satisfied: pycparser in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from cffi>=1.12->cryptography>=36.0.0->scrapy) (2.22)\n",
|
| 183 |
+
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (5.3.3)\n",
|
| 184 |
+
"Requirement already satisfied: rsa<5,>=3.1.4 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from google-auth>=1.0.1->kubernetes>=28.1.0->chromadb) (4.9)\n",
|
| 185 |
+
"Requirement already satisfied: zipp>=0.5 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from importlib-metadata<=7.0,>=6.0->opentelemetry-api>=1.2.0->chromadb) (3.17.0)\n",
|
| 186 |
+
"Requirement already satisfied: markdown-it-py>=2.2.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from rich>=10.11.0->typer>=0.9.0->chromadb) (3.0.0)\n",
|
| 187 |
+
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from rich>=10.11.0->typer>=0.9.0->chromadb) (2.17.2)\n",
|
| 188 |
+
"Requirement already satisfied: mypy-extensions>=0.3.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n",
|
| 189 |
+
"Requirement already satisfied: humanfriendly>=9.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from coloredlogs->onnxruntime>=1.14.1->chromadb) (10.0)\n",
|
| 190 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from jinja2->torch>=1.11.0->sentence-transformers) (2.1.5)\n",
|
| 191 |
+
"Requirement already satisfied: mpmath>=0.19 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from sympy->onnxruntime>=1.14.1->chromadb) (1.3.0)\n",
|
| 192 |
+
"Requirement already satisfied: pyreadline3 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from humanfriendly>=9.1->coloredlogs->onnxruntime>=1.14.1->chromadb) (3.4.1)\n",
|
| 193 |
+
"Requirement already satisfied: mdurl~=0.1 in c:\\users\\boyu\\anaconda3\\envs\\py311\\lib\\site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer>=0.9.0->chromadb) (0.1.2)\n",
|
| 194 |
+
"Downloading ollama-0.1.8-py3-none-any.whl (9.4 kB)\n",
|
| 195 |
+
"Installing collected packages: ollama\n",
|
| 196 |
+
"Successfully installed ollama-0.1.8\n"
|
| 197 |
+
]
|
| 198 |
+
}
|
| 199 |
+
],
|
| 200 |
+
"source": [
|
| 201 |
+
"!pip install reportlab yfinance matplotlib scrapy sec_api langchain umap-learn scikit-learn langchain_community tiktoken langchain-openai langchainhub chromadb langchain-anthropic sentence-transformers openbb"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": 64,
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": [
|
| 210 |
+
"import yfinance as yf\n",
|
| 211 |
+
"from matplotlib import pyplot as plt\n",
|
| 212 |
+
"from pandas.tseries.offsets import DateOffset\n",
|
| 213 |
+
"from sec_api import ExtractorApi\n",
|
| 214 |
+
"import requests\n",
|
| 215 |
+
"import pandas as pd\n",
|
| 216 |
+
"import json\n",
|
| 217 |
+
"import numpy as np\n",
|
| 218 |
+
"from openai import OpenAI\n",
|
| 219 |
+
"import os\n",
|
| 220 |
+
"from utils import get_earnings_transcript, Raptor\n",
|
| 221 |
+
"from langchain_community.embeddings.sentence_transformer import (\n",
|
| 222 |
+
" SentenceTransformerEmbeddings,\n",
|
| 223 |
+
")\n",
|
| 224 |
+
"from langchain_openai import ChatOpenAI\n",
|
| 225 |
+
"from langchain_community.vectorstores import Chroma\n",
|
| 226 |
+
"from langchain_core.output_parsers import StrOutputParser\n",
|
| 227 |
+
"from langchain import hub\n",
|
| 228 |
+
"from langchain_core.runnables import RunnablePassthrough\n",
|
| 229 |
+
"from langchain_openai import OpenAIEmbeddings\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"from openbb import obb\n",
|
| 232 |
+
"\n"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "markdown",
|
| 237 |
+
"metadata": {},
|
| 238 |
+
"source": [
|
| 239 |
+
"Need to be done: \n",
|
| 240 |
+
"~~1. pe ratio~~ \n",
|
| 241 |
+
"~~2. eps~~ \n",
|
| 242 |
+
"~~3. target price~~ \n",
|
| 243 |
+
"4. income growth"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
{
|
| 247 |
+
"cell_type": "code",
|
| 248 |
+
"execution_count": 97,
|
| 249 |
+
"metadata": {},
|
| 250 |
+
"outputs": [],
|
| 251 |
+
"source": [
|
| 252 |
+
"ticker_symbol = \"TSLA\""
|
| 253 |
+
]
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"cell_type": "code",
|
| 257 |
+
"execution_count": 102,
|
| 258 |
+
"metadata": {},
|
| 259 |
+
"outputs": [
|
| 260 |
+
{
|
| 261 |
+
"name": "stdout",
|
| 262 |
+
"output_type": "stream",
|
| 263 |
+
"text": [
|
| 264 |
+
"Using OpenAI GPT\n"
|
| 265 |
+
]
|
| 266 |
+
}
|
| 267 |
+
],
|
| 268 |
+
"source": [
|
| 269 |
+
"# define all the necessary variables\n",
|
| 270 |
+
"sec_api_key = 'YOUR_SECAPI_KEY'\n",
|
| 271 |
+
"os.environ[\"OPENAI_API_KEY\"] = \"YOUR_OPENAI_KEY\"\n",
|
| 272 |
+
"fmp_api_key = \"YOUR_FMP_API_KEY\"\n",
|
| 273 |
+
"obb.user.credentials.fmp_api_key = fmp_api_key\n",
|
| 274 |
+
"USE_CACHE = True\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"llm = \"gpt-4-turbo-preview\"\n",
|
| 277 |
+
"# llm = \"qwen:1.8b\"\n",
|
| 278 |
+
"\n",
|
| 279 |
+
"# embd = OpenAIEmbeddings()\n",
|
| 280 |
+
"# create the open-source embedding function\n",
|
| 281 |
+
"embd = SentenceTransformerEmbeddings(model_name=\"all-MiniLM-L6-v2\")\n",
|
| 282 |
+
"model = ChatOpenAI(temperature=0, model=llm)\n",
|
| 283 |
+
"rag_helper = Raptor(model, embd)\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"if 'gpt' in llm:\n",
|
| 286 |
+
" print(\"Using OpenAI GPT\")\n",
|
| 287 |
+
" client = OpenAI(\n",
|
| 288 |
+
" # This is the default and can be omitted\n",
|
| 289 |
+
" api_key=os.environ.get(\"OPENAI_API_KEY\"),\n",
|
| 290 |
+
" )\n",
|
| 291 |
+
"else:\n",
|
| 292 |
+
" print(\"Using local LLM, make sure you have installed Ollama (https://ollama.com/download) and have it running\")\n",
|
| 293 |
+
" client = OpenAI(\n",
|
| 294 |
+
" base_url = 'http://localhost:11434/v1',\n",
|
| 295 |
+
" api_key='ollama', # required, but unused\n",
|
| 296 |
+
" )"
|
| 297 |
+
]
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"cell_type": "code",
|
| 301 |
+
"execution_count": 23,
|
| 302 |
+
"metadata": {},
|
| 303 |
+
"outputs": [
|
| 304 |
+
{
|
| 305 |
+
"data": {
|
| 306 |
+
"text/plain": [
|
| 307 |
+
"OBBject\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"id: 0661870d-cba1-7731-8000-f4e086f3b086\n",
|
| 310 |
+
"results: [{'date': datetime.date(2023, 4, 12), 'open': 161.22000122070312, 'high': ...\n",
|
| 311 |
+
"provider: yfinance\n",
|
| 312 |
+
"warnings: None\n",
|
| 313 |
+
"chart: None\n",
|
| 314 |
+
"extra: {'metadata': {'arguments': {'provider_choices': {'provider': 'yfinance'}, 's..."
|
| 315 |
+
]
|
| 316 |
+
},
|
| 317 |
+
"execution_count": 23,
|
| 318 |
+
"metadata": {},
|
| 319 |
+
"output_type": "execute_result"
|
| 320 |
+
}
|
| 321 |
+
],
|
| 322 |
+
"source": [
|
| 323 |
+
"# openbb.stocks.ba.headlines(\"TSLA\")\n",
|
| 324 |
+
"# obb.equity.price.historical(\"AAPL\", provider=\"yfinance\")"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"cell_type": "code",
|
| 329 |
+
"execution_count": 74,
|
| 330 |
+
"metadata": {},
|
| 331 |
+
"outputs": [],
|
| 332 |
+
"source": [
|
| 333 |
+
"# stock.calendar"
|
| 334 |
+
]
|
| 335 |
+
},
|
| 336 |
+
{
|
| 337 |
+
"cell_type": "code",
|
| 338 |
+
"execution_count": 75,
|
| 339 |
+
"metadata": {},
|
| 340 |
+
"outputs": [],
|
| 341 |
+
"source": [
|
| 342 |
+
"# from openbb_agents.agent import openbb_agent"
|
| 343 |
+
]
|
| 344 |
+
},
|
| 345 |
+
{
|
| 346 |
+
"cell_type": "code",
|
| 347 |
+
"execution_count": 76,
|
| 348 |
+
"metadata": {},
|
| 349 |
+
"outputs": [],
|
| 350 |
+
"source": [
|
| 351 |
+
"# result = openbb_agent(\"Who are TSLA's peers? What is their respective market cap? Return the results in _descending_ order of market cap.\")"
|
| 352 |
+
]
|
| 353 |
+
},
|
| 354 |
+
{
|
| 355 |
+
"cell_type": "code",
|
| 356 |
+
"execution_count": 107,
|
| 357 |
+
"metadata": {},
|
| 358 |
+
"outputs": [],
|
| 359 |
+
"source": [
|
| 360 |
+
"# \"Develop a tailored Financial Analysis Report aligned with the user's individual needs, drawing insights from the supplied reference materials. Initiate interaction with the user to obtain essential specifics and resolve any ambiguities. Iteratively refine the Financial Analysis Report through consistent evaluations using the given evaluationRubric and gather user input to ensure the end product aligns with the users expectations. You MUST FOLLOW the rules in order.\",\"role\":\"expert level accountant\",\"department\":\"finance\",\"task\":\"Create a Financial Analysis Report\",\"task_description\":\"As an expert level accountant in the finance department, your task is to create a Financial Analysis Report that provides comprehensive insights into the financial performance and health of the company. The report should be accurate, detailed, and well-structured, showcasing key financial metrics, trends, and analysis. The finished work will be used by the management team and stakeholders to make informed decisions, identify areas of improvement, and assess the overall financial position of the company. Core success factors include attention to detail, analytical skills, and the ability to effectively communicate complex financial information. The success of the report will be measured by its ability to provide actionable recommendations and contribute to the improvement of financial decision-making processes.\"\n",
|
| 361 |
+
"\n",
|
| 362 |
+
class ReportAnalysis:
    """Assemble the data and LLM analyses behind an equity research report.

    Market data comes from yfinance, analyst price targets from
    FinancialModelingPrep, and the latest 10-K filing text from the SEC API.
    LLM calls go through the module-level OpenAI-compatible ``client``.

    Relies on these module-level globals being defined: ``sec_api_key``,
    ``fmp_api_key``, ``USE_CACHE``, ``client``, ``llm``, ``embd``, ``model``,
    ``rag_helper``, and the imports ``yf``, ``requests``, ``json``, ``os``,
    ``pd``, ``plt``, ``DateOffset``, ``ExtractorApi``, ``Chroma``, ``hub``,
    ``RunnablePassthrough``, ``StrOutputParser``, ``get_earnings_transcript``.
    """

    def __init__(self, ticker_symbol):
        # Per-ticker working directories: generated artifacts + cached intermediates.
        self.ticker_symbol = ticker_symbol
        self.stock = yf.Ticker(ticker_symbol)
        self.info = self.stock.info
        self.project_dir = f"projects/{ticker_symbol}/"
        self.cache_dir = f"projects/{ticker_symbol}/cache"
        os.makedirs(self.project_dir, exist_ok=True)
        os.makedirs(self.cache_dir, exist_ok=True)
        self.extractor = ExtractorApi(sec_api_key)
        # Resolve (and cache) the URL of the latest 10-K up front.
        self.report_address = self.get_sec_report_address()

        self.system_prompt = """
        Role: Expert Investor
        Department: Finance
        Primary Responsibility: Generation of Customized Financial Analysis Reports

        Role Description:
        As an Expert Investor within the finance domain, your expertise is harnessed to develop bespoke Financial Analysis Reports that cater to specific client requirements. This role demands a deep dive into financial statements and market data to unearth insights regarding a company's financial performance and stability. Engaging directly with clients to gather essential information and continuously refining the report with their feedback ensures the final product precisely meets their needs and expectations.

        Key Objectives:

        Analytical Precision: Employ meticulous analytical prowess to interpret financial data, identifying underlying trends and anomalies.
        Effective Communication: Simplify and effectively convey complex financial narratives, making them accessible and actionable to non-specialist audiences.
        Client Focus: Dynamically tailor reports in response to client feedback, ensuring the final analysis aligns with their strategic objectives.
        Adherence to Excellence: Maintain the highest standards of quality and integrity in report generation, following established benchmarks for analytical rigor.
        Performance Indicators:
        The efficacy of the Financial Analysis Report is measured by its utility in providing clear, actionable insights. This encompasses aiding corporate decision-making, pinpointing areas for operational enhancement, and offering a lucid evaluation of the company's financial health. Success is ultimately reflected in the report's contribution to informed investment decisions and strategic planning.
        """

    def get_target_price(self):
        """Return the latest analyst price target from FMP, or None on failure.

        Returns:
            The ``priceTarget`` value of the most recent entry, or ``None``
            when the request fails or no targets exist for the ticker.
        """
        # BUG FIX: the second query parameter must be joined with '&', not a
        # second '?', otherwise FMP receives a malformed symbol value.
        url = (
            "https://financialmodelingprep.com/api/v4/price-target"
            f"?symbol={self.ticker_symbol}&apikey={fmp_api_key}"
        )

        price_target = None
        response = requests.get(url)

        if response.status_code == 200:
            data = response.json()
            # Guard against an empty payload (unknown ticker / no coverage).
            if data:
                price_target = data[0]['priceTarget']
        else:
            print("Failed to retrieve data:", response.status_code)

        return price_target

    def get_stock_performance(self):
        """Plot 1-year % change of the ticker vs the S&P 500; return the PNG path."""

        def fetch_stock_data(ticker, period="1y"):
            stock = yf.Ticker(ticker)
            hist = stock.history(period=period)
            return hist['Close']

        target_close = fetch_stock_data(self.ticker_symbol)
        sp500_close = fetch_stock_data("^GSPC")

        # Percentage change relative to the first close of the window.
        company_change = (target_close - target_close.iloc[0]) / target_close.iloc[0] * 100
        sp500_change = (sp500_close - sp500_close.iloc[0]) / sp500_close.iloc[0] * 100

        # Extra tick positions at 4-month intervals across the window.
        start_date = company_change.index.min()
        four_months = start_date + DateOffset(months=4)
        eight_months = start_date + DateOffset(months=8)
        end_date = company_change.index.max()

        # Prepare the figure (larger font for report legibility).
        plt.rcParams.update({'font.size': 20})
        plt.figure(figsize=(14, 7))
        plt.plot(company_change.index, company_change, label=f'{self.info["shortName"]} Change %', color='blue')
        plt.plot(sp500_change.index, sp500_change, label='S&P 500 Change %', color='red')

        plt.title(f'{self.info["shortName"]} vs S&P 500 - Change % Over the Past Year')
        plt.xlabel('Date')
        plt.ylabel('Change %')

        plt.xticks([start_date, four_months, eight_months, end_date],
                   [start_date.strftime('%Y-%m'),
                    four_months.strftime('%Y-%m'),
                    eight_months.strftime('%Y-%m'),
                    end_date.strftime('%Y-%m')])

        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plot_path = f"{self.project_dir}/stock_performance.png"
        plt.savefig(plot_path)
        plt.close()
        return plot_path

    def get_pe_eps_performance(self):
        """Plot historical PE ratio and diluted EPS on twin axes; return the PNG path."""
        ss = self.get_income_stmt()
        eps = ss.loc['Diluted EPS', :]

        # Five years of daily prices so every fiscal period end is covered.
        historical_data = self.stock.history(period="5y")

        # Fiscal period end dates, oldest first, normalized to UTC so they
        # are comparable with yfinance's tz-aware index.
        dates = pd.to_datetime(eps.index[::-1], utc=True)

        # Look up the close on (or nearest trading day before) each date.
        results = {}
        for date in dates:
            if date not in historical_data.index:
                # Not a trading day: asof() falls back to the prior close.
                close_price = historical_data.asof(date)
            else:
                close_price = historical_data.loc[date]

            results[date] = close_price['Close']

        pe = [p / e for p, e in zip(results.values(), eps.values[::-1])]
        dates = eps.index[::-1]
        eps = eps.values[::-1]

        # Figure with twin y-axes: PE (left, blue) and EPS (right, red).
        fig, ax1 = plt.subplots(figsize=(14, 7))
        plt.rcParams.update({'font.size': 20})

        color = 'tab:blue'
        ax1.set_xlabel('Date')
        ax1.set_ylabel('PE Ratio', color=color)
        ax1.plot(dates, pe, color=color)
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.grid(True)

        ax2 = ax1.twinx()
        color = 'tab:red'
        ax2.set_ylabel('EPS', color=color)
        ax2.plot(dates, eps, color=color)
        ax2.tick_params(axis='y', labelcolor=color)

        plt.title(f'{self.info["shortName"]} PE Ratios and EPS Over the Past 4 Years')
        plt.xticks(rotation=45)

        plt.xticks(dates, [d.strftime('%Y-%m') for d in dates])

        plt.tight_layout()
        plot_path = f"{self.project_dir}/pe_performance.png"
        plt.savefig(plot_path)
        plt.close()
        return plot_path

    def get_sec_report_address(self):
        """Return the HTML URL of the company's most recent 10-K filing.

        The filing metadata returned by the SEC API is cached as JSON under
        the project directory so repeated runs skip the network call.

        Raises:
            RuntimeError: if the SEC API request fails or no 10-K exists.
        """
        address_json = f"{self.project_dir}/sec_report_address.json"
        if not os.path.exists(address_json):
            endpoint = f"https://api.sec-api.io?token={sec_api_key}"

            # Query the single newest 10-K filing for this ticker.
            query = {
                "query": {"query_string": {"query": f"ticker:{self.ticker_symbol} AND formType:\"10-K\""}},
                "from": "0",
                "size": "1",
                "sort": [{"filedAt": {"order": "desc"}}]
            }

            response = requests.post(endpoint, json=query)

            # BUG FIX: the original fell through on failure and hit an
            # unbound `latest_10k_url` (NameError); fail explicitly instead.
            if response.status_code != 200:
                raise RuntimeError("Failed to retrieve filings from SEC API.")

            filings = response.json()['filings']
            if not filings:
                raise RuntimeError(f"No 10-K filings found for {self.ticker_symbol}.")

            # Full filing-metadata dict (not just the URL string).
            latest_10k_url = filings[0]
            print(f"Latest 10-K report URL for {self.ticker_symbol}: {latest_10k_url}")

            with open(address_json, "w") as f:
                json.dump(latest_10k_url, f)
        else:
            with open(address_json, "r") as f:
                latest_10k_url = json.load(f)

        return latest_10k_url['linkToFilingDetails']

    def get_key_data(self):
        """Return headline market metrics (volume, price, cap, range, BVPS)."""
        # Fetch historical market data for the past 6 months.
        hist = self.stock.history(period="6mo")

        info = self.info
        close_price = hist['Close'].iloc[-1]

        # Average daily trading volume over the window.
        avg_daily_volume_6m = hist['Volume'].mean()

        result = {
            f"6m avg daily val ({info['currency']}mn)": "{:.2f}".format(avg_daily_volume_6m / 1e6),
            f"Closing Price ({info['currency']})": "{:.2f}".format(close_price),
            f"Market Cap ({info['currency']}mn)": "{:.2f}".format(info['marketCap'] / 1e6),
            f"52 Week Price Range ({info['currency']})": f"{info['fiftyTwoWeekLow']} - {info['fiftyTwoWeekHigh']}",
            f"BVPS ({info['currency']})": info['bookValue']
        }
        return result

    def get_company_info(self):
        """Return basic company identity fields from the yfinance info dict."""
        info = self.info
        result = {
            "Company Name": info['shortName'],
            "Industry": info['industry'],
            "Sector": info['sector'],
            "Country": info['country'],
            "Website": info['website']
        }
        return result

    def get_income_stmt(self):
        """Return the annual income statement DataFrame from yfinance."""
        return self.stock.financials

    def get_balance_sheet(self):
        """Return the annual balance sheet DataFrame from yfinance."""
        return self.stock.balance_sheet

    def get_cash_flow(self):
        """Return the annual cash flow statement DataFrame from yfinance."""
        return self.stock.cashflow

    def get_analyst_recommendations(self):
        """Return (majority_rating, vote_count) from the latest recommendation row."""
        recommendations = self.stock.recommendations
        row_0 = recommendations.iloc[0, 1:]  # Exclude 'period' column

        # Find the rating bucket with the most votes; ties resolve to the
        # first column in DataFrame order.
        max_votes = row_0.max()
        majority_voting_result = row_0[row_0 == max_votes].index.tolist()

        return majority_voting_result[0], max_votes

    def get_earnings(self, quarter, year):
        """Return the earnings-call transcript for the given quarter and year."""
        return get_earnings_transcript(quarter, self.ticker_symbol, year)

    def get_10k_section(self, section):
        """
        Get one section of the 10-K report from SEC EDGAR.

        Args:
            section: item number, e.g. 7 or "7A".
        Returns:
            The section text; served from the on-disk cache when enabled.
        Raises:
            ValueError: if ``section`` is not a valid 10-K item.
        """
        if section not in [1, "1A", "1B", 2, 3, 4, 5, 6, 7, "7A", 8, 9, "9A", "9B", 10, 11, 12, 13, 14, 15]:
            raise ValueError("Section must be in [1, 1A, 1B, 2, 3, 4, 5, 6, 7, 7A, 8, 9, 9A, 9B, 10, 11, 12, 13, 14, 15]")

        section = str(section)
        os.makedirs(f"{self.project_dir}/10k", exist_ok=True)

        report_name = f"{self.project_dir}/10k/section_{section}.txt"

        if USE_CACHE and os.path.exists(report_name):
            with open(report_name, "r") as f:
                section_text = f.read()
        else:
            section_text = self.extractor.get_section(self.report_address, section, "text")

            with open(report_name, "w") as f:
                f.write(section_text)

        return section_text

    def get_10k_rag(self, section):
        """Build (or load from cache) a RAG chain over one 10-K section."""
        # Vectorstore is persisted per section under the cache directory.
        vector_dir = f"{self.cache_dir}/section_{section}_vectorstore"
        if USE_CACHE and os.path.exists(vector_dir):
            vectorstore = Chroma(persist_directory=vector_dir, embedding_function=embd)
            vectorstore.get()
        else:
            section_text = self.get_10k_section(section)
            all_texts = rag_helper.text_spliter(section_text, chunk_size_tok=2000, level=1, n_levels=3)

            vectorstore = Chroma.from_texts(texts=all_texts, embedding=embd, persist_directory=vector_dir)
            vectorstore.persist()

        retriever = vectorstore.as_retriever()

        # Standard retrieval prompt from the LangChain hub.
        prompt = hub.pull("rlm/rag-prompt")

        # retriever feeds context; the question passes through untouched.
        rag_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | model
            | StrOutputParser()
        )

        return rag_chain

    def analyze_income_stmt(self):
        """Return (and cache) the LLM analysis of the income statement."""
        cache_answer = f"{self.project_dir}/income_stmt_analysis.txt"
        if USE_CACHE and os.path.exists(cache_answer):
            with open(cache_answer, "r") as f:
                answer = f.read()
        else:
            income_stmt = self.get_income_stmt()
            df_string = "Income statement:" + income_stmt.to_string().strip()

            question = "Embark on a thorough analysis of the company's income statement for the current fiscal year, focusing on revenue streams, cost of goods sold, operating expenses, and net income to discern the operational performance and profitability. Examine the gross profit margin to understand the cost efficiency, operating margin for operational effectiveness, and net profit margin to assess overall profitability. Compare these financial metrics against historical data to identify growth patterns, profitability trends, and operational challenges. Conclude with a strategic overview of the company's financial health, offering insights into revenue growth sustainability and potential areas for cost optimization and profit maximization in a single paragraph. Less than 130 words."

            answer = self.ask_question(question, 7, df_string, use_rag=False)
            with open(cache_answer, "w") as f:
                f.write(answer)
        return answer

    def analyze_balance_sheet(self):
        """Return (and cache) the LLM analysis of the balance sheet."""
        cache_answer = f"{self.project_dir}/balance_sheet_analysis.txt"
        if USE_CACHE and os.path.exists(cache_answer):
            with open(cache_answer, "r") as f:
                answer = f.read()
        else:
            balance_sheet = self.get_balance_sheet()
            df_string = "Balance sheet:" + balance_sheet.to_string().strip()

            question = "Delve into a detailed scrutiny of the company's balance sheet for the most recent fiscal year, pinpointing the structure of assets, liabilities, and shareholders' equity to decode the firm's financial stability and operational efficiency. Focus on evaluating the liquidity through current assets versus current liabilities, the solvency via long-term debt ratios, and the equity position to gauge long-term investment potential. Contrast these metrics with previous years' data to highlight financial trends, improvements, or deteriorations. Finalize with a strategic assessment of the company's financial leverage, asset management, and capital structure, providing insights into its fiscal health and future prospects in a single paragraph. Less than 130 words."

            answer = self.ask_question(question, 7, df_string, use_rag=False)
            with open(cache_answer, "w") as f:
                f.write(answer)
        return answer

    def analyze_cash_flow(self):
        """Return (and cache) the LLM analysis of the cash flow statement."""
        cache_answer = f"{self.project_dir}/cash_flow_analysis.txt"
        if USE_CACHE and os.path.exists(cache_answer):
            with open(cache_answer, "r") as f:
                answer = f.read()
        else:
            cash_flow = self.get_cash_flow()
            # BUG FIX: the prompt previously mislabeled this table as
            # "Balance sheet:", confusing the LLM about what it analyzes.
            df_string = "Cash flow:" + cash_flow.to_string().strip()

            question = "Dive into a comprehensive evaluation of the company's cash flow for the latest fiscal year, focusing on cash inflows and outflows across operating, investing, and financing activities. Examine the operational cash flow to assess the core business profitability, scrutinize investing activities for insights into capital expenditures and investments, and review financing activities to understand debt, equity movements, and dividend policies. Compare these cash movements to prior periods to discern trends, sustainability, and liquidity risks. Conclude with an informed analysis of the company's cash management effectiveness, liquidity position, and potential for future growth or financial challenges in a single paragraph. Less than 130 words."

            answer = self.ask_question(question, 7, df_string, use_rag=False)
            with open(cache_answer, "w") as f:
                f.write(answer)
        return answer

    def financial_summarization(self):
        """Run the three statement analyses and synthesize an overall summary.

        Returns:
            Dict with keys "Income Statement Analysis", "Balance Sheet
            Analysis", "Cash Flow Analysis", and "Financial Summary".
        """
        income_stmt_analysis = self.analyze_income_stmt()
        balance_sheet_analysis = self.analyze_balance_sheet()
        cash_flow_analysis = self.analyze_cash_flow()

        cache_answer = f"{self.project_dir}/financial_summarization.txt"
        if USE_CACHE and os.path.exists(cache_answer):
            with open(cache_answer, "r") as f:
                answer = f.read()
        else:
            question = f"Income statement analysis: {income_stmt_analysis}, \
            Balance sheet analysis: {balance_sheet_analysis}, \
            Cash flow analysis: {cash_flow_analysis}, \
            Synthesize the findings from the in-depth analysis of the income statement, balance sheet, and cash flow for the latest fiscal year. Highlight the core insights regarding the company's operational performance, financial stability, and cash management efficiency. Discuss the interrelations between revenue growth, cost management strategies, and their impact on profitability as revealed by the income statement. Incorporate the balance sheet's insights on financial structure, liquidity, and solvency to provide a comprehensive view of the company's financial health. Merge these with the cash flow analysis to illustrate the company's liquidity position, investment activities, and financing strategies. Conclude with a holistic assessment of the company's fiscal health, identifying strengths, potential risks, and strategic opportunities for growth and stability. Offer recommendations to address identified challenges and capitalize on the opportunities to enhance shareholder value in a single paragraph. Less than 150 words."

            answer = self.ask_question(question, 7, use_rag=False)
            with open(cache_answer, "w") as f:
                f.write(answer)
        return {"Income Statement Analysis": income_stmt_analysis, "Balance Sheet Analysis": balance_sheet_analysis, "Cash Flow Analysis": cash_flow_analysis, "Financial Summary": answer}

    def ask_question(self, question, section, table_str=None, use_rag=False):
        """Ask the LLM a question grounded in a 10-K section (and optional table).

        Args:
            question: the analysis instruction.
            section: 10-K item used as context (RAG index key when use_rag).
            table_str: optional pre-rendered financial table to prepend.
            use_rag: retrieve via the section vectorstore instead of stuffing
                section 7 text directly into the prompt.
        Returns:
            The model's answer text.
        """
        if use_rag:
            rag_chain = self.get_10k_rag(section)
            if table_str:
                prompt = f"{self.system_prompt}\n\n{table_str}\n\nQuestion: {question}"
            else:
                prompt = f"{self.system_prompt}\n\nQuestion: {question}"
            answer = rag_chain.invoke(prompt)
        else:
            # Direct completion: stuff the MD&A (section 7) text as context.
            section_text = self.get_10k_section(7)
            if table_str:
                prompt = f"{self.system_prompt}\n\n{table_str}\n\nResource: {section_text}\n\nQuestion: {question}"
            else:
                prompt = f"{self.system_prompt}\n\nResource: {section_text}\n\nQuestion: {question}"

            chat_completion = client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt.strip(),
                    }
                ],
                model=llm,
                temperature=0,
                max_tokens=500,
            )
            answer = chat_completion.choices[0].message.content

        return answer
| 770 |
+
"ra = ReportAnalysis(ticker_symbol)"
|
| 771 |
+
]
|
| 772 |
+
},
|
| 773 |
+
{
|
| 774 |
+
"cell_type": "code",
|
| 775 |
+
"execution_count": 99,
|
| 776 |
+
"metadata": {},
|
| 777 |
+
"outputs": [],
|
| 778 |
+
"source": [
|
| 779 |
+
"answer = ra.financial_summarization()"
|
| 780 |
+
]
|
| 781 |
+
},
|
| 782 |
+
{
|
| 783 |
+
"cell_type": "markdown",
|
| 784 |
+
"metadata": {},
|
| 785 |
+
"source": [
|
| 786 |
+
"### Resources to understand the financial report\n",
|
| 787 |
+
"1. income statement: https://online.hbs.edu/blog/post/income-statement-analysis\n",
|
| 788 |
+
"2. balance sheet: https://online.hbs.edu/blog/post/how-to-read-a-balance-sheet\n",
|
| 789 |
+
"3. cash flow statement: https://online.hbs.edu/blog/post/how-to-read-a-cash-flow-statement\n",
|
| 790 |
+
"4. Annual report: https://online.hbs.edu/blog/post/how-to-read-an-annual-report\n",
|
| 791 |
+
"\n",
|
| 792 |
+
"An annual report typically consists of:\n",
|
| 793 |
+
"1. Letters to shareholders: These documents provide a broad overview of the company’s activities and performance over the course of the year, as well as a reflection on its general business environment. An annual report usually includes a shareholder letter from the CEO or president, and may also contain letters from other key figures, such as the CFO.\n",
|
| 794 |
+
"2. [section 7] Management’s discussion and analysis (MD&A): This is a detailed analysis of the company’s performance, as conducted by its executives.\n",
|
| 795 |
+
"3. [section 8] Audited financial statements: These are financial documents that detail the company’s financial performance. Commonly included statements include balance sheets, cash flow statements, income statements, and equity statements.\n",
|
| 796 |
+
"4. [section 8] A summary of financial data: This refers to any notes or discussions that are pertinent to the financial statements listed above.\n",
|
| 797 |
+
"5. [section 8] Auditor’s report: This report describes whether the company has complied with generally accepted accounting principles (GAAP) in preparing its financial statements.\n",
|
| 798 |
+
"6. Accounting policies: This is an overview of the policies the company’s leadership team relied upon in preparing the annual report and financial statements.\n",
|
| 799 |
+
"\n",
|
| 800 |
+
"\n",
|
| 801 |
+
"Answer the following questions:\n",
|
| 802 |
+
"1. Whether it’s able to pay debts as they come due\n",
|
| 803 |
+
"2. Its profits and/or losses year over year\n",
|
| 804 |
+
"3. If and how it’s grown over time\n",
|
| 805 |
+
"4. What it requires to maintain or expand its business\n",
|
| 806 |
+
"5. Operational expenses compared to generated revenues"
|
| 807 |
+
]
|
| 808 |
+
},
|
| 809 |
+
{
|
| 810 |
+
"cell_type": "markdown",
|
| 811 |
+
"metadata": {},
|
| 812 |
+
"source": [
|
| 813 |
+
"## Build the pdf Report\n",
|
| 814 |
+
"\n",
|
| 815 |
+
"format: https://docs.reportlab.com/rml/tutorials/fund-reports-json-to-pdf/"
|
| 816 |
+
]
|
| 817 |
+
},
|
| 818 |
+
{
|
| 819 |
+
"cell_type": "code",
|
| 820 |
+
"execution_count": 100,
|
| 821 |
+
"metadata": {},
|
| 822 |
+
"outputs": [
|
| 823 |
+
{
|
| 824 |
+
"name": "stderr",
|
| 825 |
+
"output_type": "stream",
|
| 826 |
+
"text": [
|
| 827 |
+
"C:\\Users\\Boyu\\AppData\\Local\\Temp\\ipykernel_25956\\708270834.py:185: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
|
| 828 |
+
" df = df.applymap(convert_if_money)\n",
|
| 829 |
+
"C:\\Users\\Boyu\\AppData\\Local\\Temp\\ipykernel_25956\\708270834.py:203: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
|
| 830 |
+
" df = df.applymap(convert_if_money)\n"
|
| 831 |
+
]
|
| 832 |
+
}
|
| 833 |
+
],
|
| 834 |
+
"source": [
|
| 835 |
+
"from reportlab.lib import colors\n",
|
| 836 |
+
"from reportlab.lib import pagesizes\n",
|
| 837 |
+
"from reportlab.platypus import SimpleDocTemplate, Frame, Paragraph, Image, PageTemplate, FrameBreak, Spacer, Table, TableStyle, NextPageTemplate, PageBreak\n",
|
| 838 |
+
"from reportlab.lib.units import inch\n",
|
| 839 |
+
"from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle\n",
|
| 840 |
+
"from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY, TA_LEFT\n",
|
| 841 |
+
"\n",
|
| 842 |
+
"\n",
|
| 843 |
+
"# 2. 创建PDF并插入图像\n",
|
| 844 |
+
"# 页面设置\n",
|
| 845 |
+
"page_width, page_height = pagesizes.A4\n",
|
| 846 |
+
"left_column_width = page_width * 2/3\n",
|
| 847 |
+
"right_column_width = page_width - left_column_width\n",
|
| 848 |
+
"margin = 4\n",
|
| 849 |
+
"\n",
|
| 850 |
+
"# 创建PDF文档路径\n",
|
| 851 |
+
"pdf_path = os.path.join(ra.project_dir, f\"{ticker_symbol}_report.pdf\")\n",
|
| 852 |
+
"doc = SimpleDocTemplate(pdf_path, pagesize=pagesizes.A4)\n",
|
| 853 |
+
"\n",
|
| 854 |
+
"# 定义两个栏位的Frame\n",
|
| 855 |
+
"frame_left = Frame(margin, margin, left_column_width-margin*2, page_height-margin*2, id='left')\n",
|
| 856 |
+
"frame_right = Frame(left_column_width, margin, right_column_width-margin*2, page_height-margin*2, id='right')\n",
|
| 857 |
+
"\n",
|
| 858 |
+
"# single_frame = Frame(margin, margin, page_width-margin*2, page_height-margin*2, id='single')\n",
|
| 859 |
+
"# single_column_layout = PageTemplate(id='OneCol', frames=[single_frame])\n",
|
| 860 |
+
"\n",
|
| 861 |
+
"left_column_width_p2 = (page_width-margin*3) // 2\n",
|
| 862 |
+
"right_column_width_p2 = left_column_width_p2\n",
|
| 863 |
+
"frame_left_p2 = Frame(margin, margin, left_column_width_p2-margin*2, page_height-margin*2, id='left')\n",
|
| 864 |
+
"frame_right_p2 = Frame(left_column_width_p2, margin, right_column_width_p2-margin*2, page_height-margin*2, id='right')\n",
|
| 865 |
+
"\n",
|
| 866 |
+
"# 创建PageTemplate,并添加到文档\n",
|
| 867 |
+
"page_template = PageTemplate(id='TwoColumns', frames=[frame_left, frame_right])\n",
|
| 868 |
+
"page_template_p2 = PageTemplate(id='TwoColumns_p2', frames=[frame_left_p2, frame_right_p2])\n",
|
| 869 |
+
"doc.addPageTemplates([page_template, page_template_p2])\n",
|
| 870 |
+
"\n",
|
| 871 |
+
"styles = getSampleStyleSheet()\n",
|
| 872 |
+
"\n",
|
| 873 |
+
"# 自定义样式\n",
|
| 874 |
+
"custom_style = ParagraphStyle(\n",
|
| 875 |
+
" name=\"Custom\",\n",
|
| 876 |
+
" parent=styles['Normal'],\n",
|
| 877 |
+
" fontName=\"Helvetica\",\n",
|
| 878 |
+
" fontSize=10,\n",
|
| 879 |
+
" # leading=15,\n",
|
| 880 |
+
" alignment=TA_JUSTIFY,\n",
|
| 881 |
+
")\n",
|
| 882 |
+
"\n",
|
| 883 |
+
"title_style = ParagraphStyle(\n",
|
| 884 |
+
" name=\"TitleCustom\",\n",
|
| 885 |
+
" parent=styles['Title'],\n",
|
| 886 |
+
" fontName=\"Helvetica-Bold\",\n",
|
| 887 |
+
" fontSize=16,\n",
|
| 888 |
+
" leading=20,\n",
|
| 889 |
+
" alignment=TA_LEFT,\n",
|
| 890 |
+
" spaceAfter=10,\n",
|
| 891 |
+
")\n",
|
| 892 |
+
"\n",
|
| 893 |
+
"subtitle_style = ParagraphStyle(\n",
|
| 894 |
+
" name=\"Subtitle\",\n",
|
| 895 |
+
" parent=styles['Heading2'],\n",
|
| 896 |
+
" fontName=\"Helvetica-Bold\",\n",
|
| 897 |
+
" fontSize=14,\n",
|
| 898 |
+
" leading=12,\n",
|
| 899 |
+
" alignment=TA_LEFT,\n",
|
| 900 |
+
" spaceAfter=6,\n",
|
| 901 |
+
")\n",
|
| 902 |
+
"\n",
|
| 903 |
+
"# 准备左栏和右栏内容\n",
|
| 904 |
+
"content = []\n",
|
| 905 |
+
"# 标题\n",
|
| 906 |
+
"content.append(Paragraph(f\"Equity Research Report: {ra.get_company_info()['Company Name']}\", title_style))\n",
|
| 907 |
+
"\n",
|
| 908 |
+
"# 子标题\n",
|
| 909 |
+
"content.append(Paragraph(\"Income Statement Analysis\", subtitle_style))\n",
|
| 910 |
+
"content.append(Paragraph(answer['Income Statement Analysis'], custom_style))\n",
|
| 911 |
+
"\n",
|
| 912 |
+
"content.append(Paragraph(\"Balance Sheet Analysis\", subtitle_style))\n",
|
| 913 |
+
"content.append(Paragraph(answer['Balance Sheet Analysis'], custom_style))\n",
|
| 914 |
+
"\n",
|
| 915 |
+
"content.append(Paragraph(\"Cashflow Analysis\", subtitle_style))\n",
|
| 916 |
+
"content.append(Paragraph(answer['Cash Flow Analysis'], custom_style))\n",
|
| 917 |
+
"\n",
|
| 918 |
+
"content.append(Paragraph(\"Summarization\", subtitle_style))\n",
|
| 919 |
+
"content.append(Paragraph(answer['Financial Summary'], custom_style))\n",
|
| 920 |
+
"\n",
|
| 921 |
+
"\n",
|
| 922 |
+
"content.append(FrameBreak()) # 用于从左栏跳到右栏\n",
|
| 923 |
+
"\n",
|
| 924 |
+
"table_style = TableStyle([\n",
|
| 925 |
+
" ('BACKGROUND', (0, 0), (-1, -1), colors.white),\n",
|
| 926 |
+
" ('BACKGROUND', (0, 0), (-1, 0), colors.white),\n",
|
| 927 |
+
" ('FONT', (0, 0), (-1, -1), 'Helvetica', 8),\n",
|
| 928 |
+
" ('FONT', (0, 0), (-1, 0), 'Helvetica-Bold', 12),\n",
|
| 929 |
+
" ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),\n",
|
| 930 |
+
" # 第一列左对齐\n",
|
| 931 |
+
" ('ALIGN', (0,1), (0,-1), 'LEFT'),\n",
|
| 932 |
+
" # 第二列右对齐\n",
|
| 933 |
+
" ('ALIGN', (1,1), (1,-1), 'RIGHT'),\n",
|
| 934 |
+
" # 标题栏下方添加横线\n",
|
| 935 |
+
" ('LINEBELOW', (0,0), (-1,0), 2, colors.black),\n",
|
| 936 |
+
"])\n",
|
| 937 |
+
"full_length = right_column_width-2*margin\n",
|
| 938 |
+
"\n",
|
| 939 |
+
"rating, _ = ra.get_analyst_recommendations()\n",
|
| 940 |
+
"target_price = ra.get_target_price()\n",
|
| 941 |
+
"if target_price is not None:\n",
|
| 942 |
+
" data = [\n",
|
| 943 |
+
" [\"Rating:\", rating.upper()],\n",
|
| 944 |
+
" [\"Target Price:\", f\"{target_price:.2f}\"]\n",
|
| 945 |
+
" ]\n",
|
| 946 |
+
"else:\n",
|
| 947 |
+
" data = [[\"Rating:\", rating.upper()]]\n",
|
| 948 |
+
"col_widths = [full_length//3*2, full_length//3]\n",
|
| 949 |
+
"table = Table(data, colWidths=col_widths)\n",
|
| 950 |
+
"table.setStyle(table_style)\n",
|
| 951 |
+
"content.append(table)\n",
|
| 952 |
+
"\n",
|
| 953 |
+
"# content.append(Paragraph(\"\", custom_style))\n",
|
| 954 |
+
"content.append(Spacer(1, 0.15*inch))\n",
|
| 955 |
+
"key_data = ra.get_key_data()\n",
|
| 956 |
+
"# 表格数据\n",
|
| 957 |
+
"data = [[\"Key data\", \"\"]]\n",
|
| 958 |
+
"data += [\n",
|
| 959 |
+
" [k, v] for k, v in key_data.items()\n",
|
| 960 |
+
"]\n",
|
| 961 |
+
"col_widths = [full_length//3*2, full_length//3]\n",
|
| 962 |
+
"table = Table(data, colWidths=col_widths)\n",
|
| 963 |
+
"table.setStyle(table_style)\n",
|
| 964 |
+
"content.append(table)\n",
|
| 965 |
+
"\n",
|
| 966 |
+
"\n",
|
| 967 |
+
"# 将Matplotlib图像添加到右栏\n",
|
| 968 |
+
"\n",
|
| 969 |
+
"# 历史股价\n",
|
| 970 |
+
"data = [[\"Share Performance\"]]\n",
|
| 971 |
+
"col_widths = [full_length]\n",
|
| 972 |
+
"table = Table(data, colWidths=col_widths)\n",
|
| 973 |
+
"table.setStyle(table_style)\n",
|
| 974 |
+
"content.append(table)\n",
|
| 975 |
+
"\n",
|
| 976 |
+
"plot_path = ra.get_stock_performance()\n",
|
| 977 |
+
"width = right_column_width\n",
|
| 978 |
+
"height = width//2\n",
|
| 979 |
+
"content.append(Image(plot_path, width=width, height=height))\n",
|
| 980 |
+
"\n",
|
| 981 |
+
"# 历史PE和EPS\n",
|
| 982 |
+
"data = [[\"PE & EPS\"]]\n",
|
| 983 |
+
"col_widths = [full_length]\n",
|
| 984 |
+
"table = Table(data, colWidths=col_widths)\n",
|
| 985 |
+
"table.setStyle(table_style)\n",
|
| 986 |
+
"content.append(table)\n",
|
| 987 |
+
"\n",
|
| 988 |
+
"plot_path = ra.get_pe_eps_performance()\n",
|
| 989 |
+
"width = right_column_width\n",
|
| 990 |
+
"height = width//2\n",
|
| 991 |
+
"content.append(Image(plot_path, width=width, height=height))\n",
|
| 992 |
+
"\n",
|
| 993 |
+
"\n",
|
| 994 |
+
"# 开始新的一页\n",
|
| 995 |
+
"content.append(NextPageTemplate('TwoColumns_p2'))\n",
|
| 996 |
+
"content.append(PageBreak())\n",
|
| 997 |
+
"\n",
|
| 998 |
+
"table_style2 = TableStyle([\n",
|
| 999 |
+
" ('BACKGROUND', (0, 0), (-1, -1), colors.white),\n",
|
| 1000 |
+
" ('BACKGROUND', (0, 0), (-1, 0), colors.white),\n",
|
| 1001 |
+
" ('FONT', (0, 0), (-1, -1), 'Helvetica', 6),\n",
|
| 1002 |
+
" ('FONT', (0, 0), (-1, 0), 'Helvetica-Bold', 10),\n",
|
| 1003 |
+
" ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),\n",
|
| 1004 |
+
" # 第一列左对齐\n",
|
| 1005 |
+
" ('ALIGN', (0,1), (0,-1), 'LEFT'),\n",
|
| 1006 |
+
" # 第二列右对齐\n",
|
| 1007 |
+
" ('ALIGN', (1,1), (1,-1), 'RIGHT'),\n",
|
| 1008 |
+
" # 标题栏下方添加横线\n",
|
| 1009 |
+
" ('LINEBELOW', (0,0), (-1,0), 2, colors.black),\n",
|
| 1010 |
+
" # 表格最下方添加横线\n",
|
| 1011 |
+
" ('LINEBELOW', (0,-1), (-1,-1), 2, colors.black),\n",
|
| 1012 |
+
"])\n",
|
| 1013 |
+
"\n",
|
| 1014 |
+
"\n",
|
| 1015 |
+
"# 第二页及之后内容,使用单栏布局\n",
|
| 1016 |
+
"df = ra.get_income_stmt()\n",
|
| 1017 |
+
"df = df[df.columns[:3]]\n",
|
| 1018 |
+
"def convert_if_money(value):\n",
|
| 1019 |
+
" if np.abs(value) >= 1000000:\n",
|
| 1020 |
+
" return value / 1000000\n",
|
| 1021 |
+
" else:\n",
|
| 1022 |
+
" return value\n",
|
| 1023 |
+
"\n",
|
| 1024 |
+
"# 应用转换函数到DataFrame的每列\n",
|
| 1025 |
+
"df = df.applymap(convert_if_money)\n",
|
| 1026 |
+
"\n",
|
| 1027 |
+
"df.columns = [col.strftime('%Y') for col in df.columns]\n",
|
| 1028 |
+
"df.reset_index(inplace=True)\n",
|
| 1029 |
+
"currency = ra.info['currency']\n",
|
| 1030 |
+
"df.rename(columns={'index': f'FY ({currency} mn)'}, inplace=True) # 可选:重命名索引列为“序号”\n",
|
| 1031 |
+
"table_data = [[\"Income Statement\"]]\n",
|
| 1032 |
+
"table_data += [df.columns.to_list()] + df.values.tolist()\n",
|
| 1033 |
+
"\n",
|
| 1034 |
+
"table = Table(table_data)\n",
|
| 1035 |
+
"table.setStyle(table_style2)\n",
|
| 1036 |
+
"content.append(table)\n",
|
| 1037 |
+
"\n",
|
| 1038 |
+
"content.append(FrameBreak()) # 用于从左栏跳到右栏\n",
|
| 1039 |
+
"\n",
|
| 1040 |
+
"df = ra.get_cash_flow()\n",
|
| 1041 |
+
"df = df[df.columns[:3]]\n",
|
| 1042 |
+
"\n",
|
| 1043 |
+
"df = df.applymap(convert_if_money)\n",
|
| 1044 |
+
"\n",
|
| 1045 |
+
"df.columns = [col.strftime('%Y') for col in df.columns]\n",
|
| 1046 |
+
"df.reset_index(inplace=True)\n",
|
| 1047 |
+
"currency = ra.info['currency']\n",
|
| 1048 |
+
"df.rename(columns={'index': f'FY ({currency} mn)'}, inplace=True) # 可选:重命名索引列为“序号”\n",
|
| 1049 |
+
"table_data = [[\"Cash Flow Sheet\"]]\n",
|
| 1050 |
+
"table_data += [df.columns.to_list()] + df.values.tolist()\n",
|
| 1051 |
+
"\n",
|
| 1052 |
+
"table = Table(table_data)\n",
|
| 1053 |
+
"table.setStyle(table_style2)\n",
|
| 1054 |
+
"content.append(table)\n",
|
| 1055 |
+
"# content.append(Paragraph('This is a single column on the second page', custom_style))\n",
|
| 1056 |
+
"# content.append(Spacer(1, 0.2*inch))\n",
|
| 1057 |
+
"# content.append(Paragraph('More content in the single column.', custom_style))\n",
|
| 1058 |
+
"\n",
|
| 1059 |
+
"# 构建PDF文档\n",
|
| 1060 |
+
"doc.build(content)\n"
|
| 1061 |
+
]
|
| 1062 |
+
}
|
| 1063 |
+
],
|
| 1064 |
+
"metadata": {
|
| 1065 |
+
"kernelspec": {
|
| 1066 |
+
"display_name": "base",
|
| 1067 |
+
"language": "python",
|
| 1068 |
+
"name": "python3"
|
| 1069 |
+
},
|
| 1070 |
+
"language_info": {
|
| 1071 |
+
"codemirror_mode": {
|
| 1072 |
+
"name": "ipython",
|
| 1073 |
+
"version": 3
|
| 1074 |
+
},
|
| 1075 |
+
"file_extension": ".py",
|
| 1076 |
+
"mimetype": "text/x-python",
|
| 1077 |
+
"name": "python",
|
| 1078 |
+
"nbconvert_exporter": "python",
|
| 1079 |
+
"pygments_lexer": "ipython3",
|
| 1080 |
+
"version": "3.11.8"
|
| 1081 |
+
}
|
| 1082 |
+
},
|
| 1083 |
+
"nbformat": 4,
|
| 1084 |
+
"nbformat_minor": 2
|
| 1085 |
+
}
|
fingpt/FinGPT_FinancialReportAnalysis/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Package exports for fingpt.FinGPT_FinancialReportAnalysis.utils.
#
# Use explicit relative imports so the package resolves regardless of the
# caller's working directory; the previous absolute `utils.`-prefixed form
# only worked when the project root happened to be on sys.path.
from .earning_calls import get_earnings_transcript, extract_speakers
from .rag import Raptor
|
fingpt/FinGPT_FinancialReportAnalysis/utils/earning_calls.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
| 2 |
+
import requests
|
| 3 |
+
import json
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import re
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def correct_date(yr, dt):
    """Normalize a transcript timestamp whose year field is wrong.

    Some transcripts come back with a timestamp that disagrees with the
    record's fiscal year; force the year component to match.

    Args:
        yr (int): the actual year the transcript belongs to
        dt (str): timestamp in "%Y-%m-%d %H:%M:%S" format

    Returns:
        str: the same timestamp, re-serialized with its year set to ``yr``
    """
    parsed = datetime.strptime(dt, "%Y-%m-%d %H:%M:%S")
    corrected = parsed if parsed.year == yr else parsed.replace(year=yr)
    return corrected.strftime("%Y-%m-%d %H:%M:%S")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def extract_speakers(cont: str) -> List[str]:
    """Collect the unique speaker names from a transcript.

    A speaker is any text between a newline and the next colon
    (lines of the form "\\nName: ...").

    Args:
        cont (str): full transcript content

    Returns:
        List[str]: de-duplicated speaker names (order not guaranteed)
    """
    found = re.findall(r"\n(.*?):", cont)
    return list(set(found))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(2))
def get_earnings_transcript(quarter: str, ticker: str, year: int):
    """Fetch an earnings-call transcript from discountingcashflows.com.

    Retried up to 2 times with random exponential backoff on failure.

    Args:
        quarter (str): fiscal quarter, e.g. "Q4"
        ticker (str): stock ticker symbol, e.g. "AAPL"
        year (int): fiscal year, e.g. 2023

    Returns:
        dict: the first transcript record from the API response, with its
            "date" field corrected to agree with its "year" field.

    Raises:
        requests.HTTPError: on a non-2xx response (lets @retry re-attempt).
    """
    response = requests.get(
        f"https://discountingcashflows.com/api/transcript/{ticker}/{quarter}/{year}/",
        auth=("user", "pass"),
        timeout=30,  # never hang indefinitely on a stalled connection
    )
    # Surface HTTP errors explicitly instead of failing later on JSON parsing.
    response.raise_for_status()

    resp_text = response.json()
    # speakers_list = extract_speakers(resp_text[0]["content"])
    corrected_date = correct_date(resp_text[0]["year"], resp_text[0]["date"])
    resp_text[0]["date"] = corrected_date
    return resp_text[0]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# from utils import get_earnings_transcript
|
| 63 |
+
|
| 64 |
+
# quarter = "Q4"
|
| 65 |
+
# ticker = "AAPL"
|
| 66 |
+
# year = 2023
|
| 67 |
+
# resp_dict, speakers_list = get_earnings_transcript(
|
| 68 |
+
# quarter, ticker, year
|
| 69 |
+
# )
|
fingpt/FinGPT_FinancialReportAnalysis/utils/format_pdf.py
ADDED
|
File without changes
|
fingpt/FinGPT_FinancialReportAnalysis/utils/rag.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import umap
|
| 6 |
+
from langchain.prompts import ChatPromptTemplate
|
| 7 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 8 |
+
from sklearn.mixture import GaussianMixture
|
| 9 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
RANDOM_SEED = 224 # Fixed seed for reproducibility
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Raptor:
    """RAPTOR-style recursive embed -> cluster -> summarize pipeline.

    Text chunks are embedded, reduced with UMAP, clustered with Gaussian
    Mixture Models (globally, then locally within each global cluster), each
    cluster is summarized by the chat model, and the summaries become the
    input texts for the next recursion level.
    """

    def __init__(self, model, embed):
        # model: chat model used for cluster summarization; composed into an
        #   LCEL chain (prompt | model | StrOutputParser) in
        #   embed_cluster_summarize_texts.
        # embed: embeddings object; must expose `embed_documents(texts)`
        #   (see `embed` below).
        self.model = model
        self.embd = embed

    def global_cluster_embeddings(
        self,
        embeddings: np.ndarray,
        dim: int,
        n_neighbors: Optional[int] = None,
        metric: str = "cosine",
    ) -> np.ndarray:
        """
        Perform global dimensionality reduction on the embeddings using UMAP.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - dim: The target dimensionality for the reduced space.
        - n_neighbors: Optional; the number of neighbors to consider for each point.
          If not provided, it defaults to the square root of the number of embeddings.
        - metric: The distance metric to use for UMAP.

        Returns:
        - A numpy array of the embeddings reduced to the specified dimensionality.
        """
        if n_neighbors is None:
            # sqrt-of-corpus heuristic: balances local vs. global structure.
            n_neighbors = int((len(embeddings) - 1) ** 0.5)
        return umap.UMAP(
            n_neighbors=n_neighbors, n_components=dim, metric=metric
        ).fit_transform(embeddings)

    def local_cluster_embeddings(
        self, embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
    ) -> np.ndarray:
        """
        Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - dim: The target dimensionality for the reduced space.
        - num_neighbors: The number of neighbors to consider for each point.
        - metric: The distance metric to use for UMAP.

        Returns:
        - A numpy array of the embeddings reduced to the specified dimensionality.
        """
        return umap.UMAP(
            n_neighbors=num_neighbors, n_components=dim, metric=metric
        ).fit_transform(embeddings)

    def get_optimal_clusters(
        self, embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED
    ) -> int:
        """
        Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - max_clusters: The maximum number of clusters to consider.
        - random_state: Seed for reproducibility.

        Returns:
        - An integer representing the optimal number of clusters found.
        """
        # Never try more clusters than there are points.
        max_clusters = min(max_clusters, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        # Lowest BIC wins.
        return n_clusters[np.argmin(bics)]

    def GMM_cluster(self, embeddings: np.ndarray, threshold: float, random_state: int = 0):
        """
        Cluster embeddings using a Gaussian Mixture Model (GMM) based on a probability threshold.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - threshold: The probability threshold for assigning an embedding to a cluster.
        - random_state: Seed for reproducibility.
          NOTE(review): defaults to 0 here, while get_optimal_clusters defaults
          to RANDOM_SEED — presumably intentional, but worth confirming.

        Returns:
        - A tuple containing the cluster labels and the number of clusters determined.
        """
        n_clusters = self.get_optimal_clusters(embeddings)
        gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
        gm.fit(embeddings)
        probs = gm.predict_proba(embeddings)
        # Soft assignment: one embedding may belong to several clusters
        # (every cluster whose posterior probability exceeds `threshold`).
        labels = [np.where(prob > threshold)[0] for prob in probs]
        return labels, n_clusters

    def perform_clustering(
        self,
        embeddings: np.ndarray,
        dim: int,
        threshold: float,
    ) -> List[np.ndarray]:
        """
        Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering
        using a Gaussian Mixture Model, and finally performing local clustering within each global cluster.

        Parameters:
        - embeddings: The input embeddings as a numpy array.
        - dim: The target dimensionality for UMAP reduction.
        - threshold: The probability threshold for assigning an embedding to a cluster in GMM.

        Returns:
        - A list of numpy arrays, where each array contains the cluster IDs for each embedding.
        """
        if len(embeddings) <= dim + 1:
            # Avoid clustering when there's insufficient data
            return [np.array([0]) for _ in range(len(embeddings))]

        # Global dimensionality reduction
        reduced_embeddings_global = self.global_cluster_embeddings(embeddings, dim)
        # Global clustering
        global_clusters, n_global_clusters = self.GMM_cluster(
            reduced_embeddings_global, threshold
        )

        all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
        total_clusters = 0

        # Iterate through each global cluster to perform local clustering
        for i in range(n_global_clusters):
            # Extract embeddings belonging to the current global cluster
            global_cluster_embeddings_ = embeddings[
                np.array([i in gc for gc in global_clusters])
            ]

            if len(global_cluster_embeddings_) == 0:
                continue
            if len(global_cluster_embeddings_) <= dim + 1:
                # Handle small clusters with direct assignment
                local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
                n_local_clusters = 1
            else:
                # Local dimensionality reduction and clustering
                reduced_embeddings_local = self.local_cluster_embeddings(
                    global_cluster_embeddings_, dim
                )
                local_clusters, n_local_clusters = self.GMM_cluster(
                    reduced_embeddings_local, threshold
                )

            # Assign local cluster IDs, adjusting for total clusters already processed
            for j in range(n_local_clusters):
                local_cluster_embeddings_ = global_cluster_embeddings_[
                    np.array([j in lc for lc in local_clusters])
                ]
                # NOTE(review): maps rows back to their original positions by
                # exact row equality — assumes no two distinct chunks share an
                # identical embedding vector, and is O(n*m); confirm acceptable
                # for the corpus sizes used here.
                indices = np.where(
                    (embeddings == local_cluster_embeddings_[:, None]).all(-1)
                )[1]
                for idx in indices:
                    all_local_clusters[idx] = np.append(
                        all_local_clusters[idx], j + total_clusters
                    )

            # Offset so cluster IDs are globally unique across global clusters.
            total_clusters += n_local_clusters

        return all_local_clusters

    ### --- Our code below --- ###

    def embed(self, texts):
        """
        Generate embeddings for a list of text documents.

        This function assumes the existence of an `embd` object with a method `embed_documents`
        that takes a list of texts and returns their embeddings.

        Parameters:
        - texts: List[str], a list of text documents to be embedded.

        Returns:
        - numpy.ndarray: An array of embeddings for the given text documents.
        """
        text_embeddings = self.embd.embed_documents(texts)
        text_embeddings_np = np.array(text_embeddings)
        return text_embeddings_np

    def embed_cluster_texts(self, texts):
        """
        Embeds a list of texts and clusters them, returning a DataFrame with texts, their embeddings, and cluster labels.

        This function combines embedding generation and clustering into a single step. It assumes the existence
        of a previously defined `perform_clustering` function that performs clustering on the embeddings.

        Parameters:
        - texts: List[str], a list of text documents to be processed.

        Returns:
        - pandas.DataFrame: A DataFrame containing the original texts, their embeddings, and the assigned cluster labels.
        """
        text_embeddings_np = self.embed(texts)  # Generate embeddings
        # dim=10 and threshold=0.1 are fixed pipeline hyperparameters here.
        cluster_labels = self.perform_clustering(
            text_embeddings_np, 10, 0.1
        )  # Perform clustering on the embeddings
        df = pd.DataFrame()  # Initialize a DataFrame to store the results
        df["text"] = texts  # Store original texts
        df["embd"] = list(text_embeddings_np)  # Store embeddings as a list in the DataFrame
        df["cluster"] = cluster_labels  # Store cluster labels
        return df

    def fmt_txt(self, df: pd.DataFrame) -> str:
        """
        Formats the text documents in a DataFrame into a single string.

        Parameters:
        - df: DataFrame containing the 'text' column with text documents to format.

        Returns:
        - A single string where all text documents are joined by a specific delimiter.
        """
        unique_txt = df["text"].tolist()
        return "--- --- \n --- --- ".join(unique_txt)

    def embed_cluster_summarize_texts(
        self, texts: List[str], level: int
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Embeds, clusters, and summarizes a list of texts. This function first generates embeddings for the texts,
        clusters them based on similarity, expands the cluster assignments for easier processing, and then summarizes
        the content within each cluster.

        Parameters:
        - texts: A list of text documents to be processed.
        - level: An integer parameter that could define the depth or detail of processing.

        Returns:
        - Tuple containing two DataFrames:
          1. The first DataFrame (`df_clusters`) includes the original texts, their embeddings, and cluster assignments.
          2. The second DataFrame (`df_summary`) contains summaries for each cluster, the specified level of detail,
             and the cluster identifiers.
        """

        # Embed and cluster the texts, resulting in a DataFrame with 'text', 'embd', and 'cluster' columns
        df_clusters = self.embed_cluster_texts(texts)

        # Prepare to expand the DataFrame for easier manipulation of clusters
        expanded_list = []

        # Expand DataFrame entries to document-cluster pairings for straightforward processing
        # (soft assignment means one text row may appear under several clusters).
        for index, row in df_clusters.iterrows():
            for cluster in row["cluster"]:
                expanded_list.append(
                    {"text": row["text"], "embd": row["embd"], "cluster": cluster}
                )

        # Create a new DataFrame from the expanded list
        expanded_df = pd.DataFrame(expanded_list)

        # Retrieve unique cluster identifiers for processing
        all_clusters = expanded_df["cluster"].unique()

        print(f"--Generated {len(all_clusters)} clusters--")

        # Summarization
        # NOTE(review): prompt wording references "LangChain Expression
        # Langauge doc" (sic) — inherited from the upstream RAPTOR example;
        # the model is simply asked for a detailed summary of {context}.
        template = """Here is a sub-set of LangChain Expression Langauge doc.

    LangChain Expression Langauge provides a way to compose chain in LangChain.

    Give a detailed summary of the documentation provided.

    Documentation:
    {context}
    """
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | self.model | StrOutputParser()

        # Format text within each cluster for summarization
        summaries = []
        for i in all_clusters:
            df_cluster = expanded_df[expanded_df["cluster"] == i]
            formatted_txt = self.fmt_txt(df_cluster)
            summaries.append(chain.invoke({"context": formatted_txt}))

        # Create a DataFrame to store summaries with their corresponding cluster and level
        df_summary = pd.DataFrame(
            {
                "summaries": summaries,
                "level": [level] * len(summaries),
                "cluster": list(all_clusters),
            }
        )

        return df_clusters, df_summary

    def recursive_embed_cluster_summarize(
        self, texts: List[str], level: int = 1, n_levels: int = 3
    ) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
        """
        Recursively embeds, clusters, and summarizes texts up to a specified level or until
        the number of unique clusters becomes 1, storing the results at each level.

        Parameters:
        - texts: List[str], texts to be processed.
        - level: int, current recursion level (starts at 1).
        - n_levels: int, maximum depth of recursion.

        Returns:
        - Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], a dictionary where keys are the recursion
          levels and values are tuples containing the clusters DataFrame and summaries DataFrame at that level.
        """
        results = {}  # Dictionary to store results at each level

        # Perform embedding, clustering, and summarization for the current level
        df_clusters, df_summary = self.embed_cluster_summarize_texts(texts, level)

        # Store the results of the current level
        results[level] = (df_clusters, df_summary)

        # Determine if further recursion is possible and meaningful
        unique_clusters = df_summary["cluster"].nunique()
        if level < n_levels and unique_clusters > 1:
            # Use summaries as the input texts for the next level of recursion
            new_texts = df_summary["summaries"].tolist()
            next_level_results = self.recursive_embed_cluster_summarize(
                new_texts, level + 1, n_levels
            )

            # Merge the results from the next level into the current results dictionary
            results.update(next_level_results)

        return results

    def text_spliter(self, text, chunk_size_tok=2000, level=1, n_levels=3):
        """
        Split a document into token-sized chunks and augment them with
        recursively generated cluster summaries.

        NOTE(review): name is a typo for "text_splitter", but it is the public
        API (re-exported callers may depend on it), so it is kept as-is.

        Parameters:
        - text: str, text to be processed.
        - chunk_size_tok: int, size of each chunk in tokens.
        - level: int, current recursion level (starts at 1).
        - n_levels: int, maximum depth of recursion.
        Returns:
        - List[str], all texts after recursive embedding, clustering, and summarization.
        """
        if text is None:
            raise ValueError("Text cannot be None.")

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size_tok, chunk_overlap=0
        )
        texts_split = text_splitter.split_text(text)
        if texts_split is None or len(texts_split) == 0:
            raise ValueError("Text splitting did not produce any text chunks.")

        results = self.recursive_embed_cluster_summarize(texts_split, level=level, n_levels=n_levels)
        if results is None:
            raise ValueError("Recursive embedding and clustering did not produce any results.")

        # Start from the raw chunks, then append every level's summaries.
        all_texts = texts_split.copy()

        for level in sorted(results.keys()):
            # Extract summaries from the current level's DataFrame
            if results[level] is None or len(results[level]) != 2:
                raise ValueError(f"Unexpected results format at level {level}.")
            summaries = results[level][1]["summaries"].tolist()
            if summaries is None or len(summaries) == 0:
                raise ValueError(f"Level {level} did not produce any summaries.")
            # Extend all_texts with the summaries from the current level
            all_texts.extend(summaries)

        return all_texts
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# def text_spliter(self, text, chunk_size_tok=2000, level=1, n_levels=3):
|
| 392 |
+
# - """
|
| 393 |
+
# - Parameters:
|
| 394 |
+
# - - texts: List[str], texts to be processed.
|
| 395 |
+
# - - level: int, current recursion level (starts at 1).
|
| 396 |
+
# - - n_levels: int, maximum depth of recursion.
|
| 397 |
+
# - """
|
| 398 |
+
# - text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
| 399 |
+
# - chunk_size=chunk_size_tok, chunk_overlap=0
|
| 400 |
+
# - )
|
| 401 |
+
# - texts_split = text_splitter.split_text(text)
|
| 402 |
+
# - results = self.recursive_embed_cluster_summarize(texts_split, level=level, n_levels=n_levels)
|
| 403 |
+
# -
|
| 404 |
+
# - all_texts = texts_split.copy()
|
| 405 |
+
# -
|
| 406 |
+
# - for level in sorted(results.keys()):
|
| 407 |
+
# - # Extract summaries from the current level's DataFrame
|
| 408 |
+
# - summaries = results[level][1]["summaries"].tolist()
|
| 409 |
+
# - # Extend all_texts with the summaries from the current level
|
| 410 |
+
# - all_texts.extend(summaries)
|
| 411 |
+
# -
|
| 412 |
+
# - return all_texts
|
fingpt/FinGPT_Forecaster/AAAI-Good-Data/README.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
## What is FinGPT-Forecaster?
|
| 3 |
+
- FinGPT-Forecaster takes market news and optional basic financials related to the specified company from the past few weeks as input and responds with the company's **positive developments** and **potential concerns**. Then it gives out a **prediction** of stock price movement for the coming week and its **analysis** summary.
|
| 4 |
+
- FinGPT-Forecaster is finetuned on Llama-2-7b-chat-hf with LoRA on the past year's DOW30 market data, but it has also shown great generalization ability on other ticker symbols.
|
| 5 |
+
- FinGPT-Forecaster is an easy-to-deploy junior robo-advisor, a milestone towards our goal.
|
| 6 |
+
|
| 7 |
+
## Try out the demo!
|
| 8 |
+
|
| 9 |
+
Try our demo at <https://huggingface.co/spaces/FinGPT/FinGPT-Forecaster>
|
| 10 |
+
|
| 11 |
+

|
| 12 |
+
|
| 13 |
+
Enter the following inputs:
|
| 14 |
+
|
| 15 |
+
1) ticker symbol (e.g. AAPL, MSFT, NVDA)
|
| 16 |
+
2) the day from which you want the prediction to happen (yyyy-mm-dd)
|
| 17 |
+
3) the number of past weeks where market news are retrieved
|
| 18 |
+
4) whether to add latest basic financials as additional information
|
| 19 |
+
|
| 20 |
+
Then, click Submit! You'll get a response like this:
|
| 21 |
+
|
| 22 |
+

|
| 23 |
+
|
| 24 |
+
This is just a demo showing what this model is capable of. Results inferred from randomly chosen news can be strongly biased.
|
| 25 |
+
For more detailed and customized usage, scroll down and continue your reading.
|
| 26 |
+
|
| 27 |
+
## Deploy FinGPT-Forecaster
|
| 28 |
+
|
| 29 |
+
We have released our FinGPT-Forecaster trained on DOW30 market data from 2022-12-30 to 2023-9-1 on HuggingFace: [fingpt-forecaster_dow30_llama2-7b_lora](https://huggingface.co/FinGPT/fingpt-forecaster_dow30_llama2-7b_lora)
|
| 30 |
+
|
| 31 |
+
We have most of the key requirements in `requirements.txt`. Before you start, do `pip install -r requirements.txt`. Then you can refer to `demo.ipynb` for our deployment and evaluation script.
|
| 32 |
+
|
| 33 |
+
First let's load the model:
|
| 34 |
+
|
| 35 |
+
```
|
| 36 |
+
from datasets import load_dataset
|
| 37 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 38 |
+
from peft import PeftModel
|
| 39 |
+
import torch
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 43 |
+
'meta-llama/Llama-2-7b-chat-hf',
|
| 44 |
+
trust_remote_code=True,
|
| 45 |
+
device_map="auto",
|
| 46 |
+
torch_dtype=torch.float16, # optional if you have enough VRAM
|
| 47 |
+
)
|
| 48 |
+
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
|
| 49 |
+
|
| 50 |
+
model = PeftModel.from_pretrained(base_model, 'FinGPT/fingpt-forecaster_dow30_llama2-7b_lora')
|
| 51 |
+
model = model.eval()
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Then you are ready to go, prepare your prompt with news & stock price movements in llama format (which we'll mention in the next section), and generate your own forecasting results!
|
| 55 |
+
```
|
| 56 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
| 57 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
| 58 |
+
|
| 59 |
+
prompt = B_INST + B_SYS + {SYSTEM_PROMPT} + E_SYS + {YOUR_PROMPT} + E_INST
|
| 60 |
+
inputs = tokenizer(
|
| 61 |
+
prompt, return_tensors='pt'
|
| 62 |
+
)
|
| 63 |
+
inputs = {key: value.to(model.device) for key, value in inputs.items()}
|
| 64 |
+
|
| 65 |
+
res = model.generate(
|
| 66 |
+
**inputs, max_length=4096, do_sample=True,
|
| 67 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 68 |
+
use_cache=True
|
| 69 |
+
)
|
| 70 |
+
output = tokenizer.decode(res[0], skip_special_tokens=True)
|
| 71 |
+
answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL) # don't forget to import re
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
## Data Preparation
|
| 75 |
+
Company profile & Market news & Basic financials & Stock prices are retrieved using **yfinance & finnhub**.
|
| 76 |
+
|
| 77 |
+
Prompts used are organized as below:
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
SYSTEM_PROMPT = "You are a seasoned stock market analyst. Your task is to list the positive developments and potential concerns for companies based on relevant news and basic financials from the past weeks, then provide an analysis and prediction for the companies' stock price movement for the upcoming week. Your answer format should be as follows:\n\n[Positive Developments]:\n1. ...\n\n[Potential Concerns]:\n1. ...\n\n[Prediction & Analysis]:\n...\n"
|
| 81 |
+
|
| 82 |
+
prompt = """
|
| 83 |
+
[Company Introduction]:
|
| 84 |
+
|
| 85 |
+
{name} is a leading entity in the {finnhubIndustry} sector. Incorporated and publicly traded since {ipo}, the company has established its reputation as one of the key players in the market. As of today, {name} has a market capitalization of {marketCapitalization:.2f} in {currency}, with {shareOutstanding:.2f} shares outstanding. {name} operates primarily in the {country}, trading under the ticker {ticker} on the {exchange}. As a dominant force in the {finnhubIndustry} space, the company continues to innovate and drive progress within the industry.
|
| 86 |
+
|
| 87 |
+
From {startDate} to {endDate}, {name}'s stock price {increase/decrease} from {startPrice} to {endPrice}. Company news during this period are listed below:
|
| 88 |
+
|
| 89 |
+
[Headline]: ...
|
| 90 |
+
[Summary]: ...
|
| 91 |
+
|
| 92 |
+
[Headline]: ...
|
| 93 |
+
[Summary]: ...
|
| 94 |
+
|
| 95 |
+
Some recent basic financials of {name}, reported at {date}, are presented below:
|
| 96 |
+
|
| 97 |
+
[Basic Financials]:
|
| 98 |
+
{attr1}: {value1}
|
| 99 |
+
{attr2}: {value2}
|
| 100 |
+
...
|
| 101 |
+
|
| 102 |
+
Based on all the information before {curday}, let's first analyze the positive developments and potential concerns for {symbol}. Come up with 2-4 most important factors respectively and keep them concise. Most factors should be inferred from company-related news. Then make your prediction of the {symbol} stock price movement for next week ({period}). Provide a summary analysis to support your prediction.
|
| 103 |
+
|
| 104 |
+
"""
|
| 105 |
+
```
|
| 106 |
+
## Train your own FinGPT-Forecaster
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
**Disclaimer: Nothing herein is financial advice, and NOT a recommendation to trade real money. Please use common sense and always first consult a professional before trading or investing.**
|
fingpt/FinGPT_Forecaster/AAAI-Good-Data/Testing.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
fingpt/FinGPT_Forecaster/AAAI-Good-Data/Training.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
fingpt/FinGPT_Forecaster/AAAI-Good-Data/config.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_micro_batch_size_per_gpu": 2,
|
| 3 |
+
"train_batch_size": 16,
|
| 4 |
+
"gradient_accumulation_steps": 8,
|
| 5 |
+
"optimizer": {
|
| 6 |
+
"type": "Adam",
|
| 7 |
+
"params": {
|
| 8 |
+
"lr": 5e-5,
|
| 9 |
+
"weight_decay": 0.01,
|
| 10 |
+
"bias_correction": false
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
"scheduler": {
|
| 14 |
+
"type": "WarmupLR",
|
| 15 |
+
"params": {
|
| 16 |
+
"warmup_min_lr": 0,
|
| 17 |
+
"warmup_max_lr": 5e-5,
|
| 18 |
+
"warmup_num_steps": "auto"
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"fp16": {
|
| 22 |
+
"enabled": true
|
| 23 |
+
},
|
| 24 |
+
"bf16": {
|
| 25 |
+
"enabled": false
|
| 26 |
+
},
|
| 27 |
+
"zero_optimization": {
|
| 28 |
+
"stage": 2,
|
| 29 |
+
"offload_optimizer": {
|
| 30 |
+
"device": "none"
|
| 31 |
+
},
|
| 32 |
+
"offload_param": {
|
| 33 |
+
"device": "none"
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"activation_checkpointing": {
|
| 37 |
+
"partition_activations": true,
|
| 38 |
+
"contiguous_memory_optimization": true
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
|
fingpt/FinGPT_Forecaster/AAAI-Good-Data/train.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment tweaks for multi-GPU training in restricted environments (e.g. Colab).
export NCCL_IGNORE_DISABLED_P2P=1
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
export TOKENIZERS_PARALLELISM=0

# Launch LoRA fine-tuning of Llama-3-8B under DeepSpeed.
# BUG FIX: the last option previously ended with a trailing "\" line
# continuation followed by a blank line, which the shell silently swallowed.
deepspeed \
    train_lora.py \
    --run_name llama3-8b-a100-5e-5lr \
    --base_model llama3 \
    --dataset "/content/drive/MyDrive/Colab Notebooks/AI4Finance/FinForecaster/Benchmark with Llama3 8b Data/fingpt-forecaster-1105/train/" \
    --test_dataset "/content/drive/MyDrive/Colab Notebooks/AI4Finance/FinForecaster/Benchmark with Llama3 8b Data/fingpt-forecaster-1105/test/" \
    --max_length 8000 \
    --batch_size 2 \
    --gradient_accumulation_steps 8 \
    --learning_rate 5e-5 \
    --num_epochs 5 \
    --log_interval 10 \
    --warmup_ratio 0.03 \
    --scheduler constant \
    --evaluation_strategy steps \
    --ds_config config.json
fingpt/FinGPT_Forecaster/AAAI-Good-Data/train_lora.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.integrations import TensorBoardCallback
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
from transformers import TrainerCallback, TrainerState, TrainerControl
from torch.utils.tensorboard import SummaryWriter
import datasets
import torch
import sys
import os
import re
import wandb
import argparse
from datetime import datetime
from functools import partial
from tqdm import tqdm
from utils import lora_module_dict, parse_model_name, load_dataset, tokenize, calc_metrics

# LoRA
from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    set_peft_model_state_dict,
)

# Weights & Biases configuration.
# SECURITY FIX: an actual API key used to be hard-coded here.  A credential
# committed to source control is leaked and must be rotated.  Provide it via
# the WANDB_API_KEY environment variable (or `wandb login`) before launching.
os.environ.setdefault('WANDB_PROJECT', 'Benchmark with Llama-3-8B')
|
| 30 |
+
class GenerationEvalCallback(TrainerCallback):
    """Trainer callback that, at evaluation time, generates free-form answers
    for a held-out dataset and logs generation metrics (direction accuracy,
    MSE, ROUGE) to Weights & Biases.

    Args:
        eval_dataset: iterable of dicts with 'prompt' and 'answer' keys.
        ignore_until_epoch: skip generation-based evaluation before this epoch
            (generation is expensive and early checkpoints are uninformative).
    """

    def __init__(self, eval_dataset, ignore_until_epoch=0):
        self.eval_dataset = eval_dataset
        self.ignore_until_epoch = ignore_until_epoch

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if state.epoch is None or state.epoch + 1 < self.ignore_until_epoch:
            return

        # Only the main local process runs (and logs) the expensive generation pass.
        if state.is_local_process_zero:
            model = kwargs['model']
            tokenizer = kwargs['tokenizer']
            generated_texts, reference_texts = [], []

            for feature in tqdm(self.eval_dataset):
                prompt = feature['prompt']
                gt = feature['answer']
                # NOTE(review): max_length without truncation=True does not
                # truncate; prompts longer than 8000 tokens pass through
                # unchanged — confirm whether truncation is intended here.
                inputs = tokenizer(
                    prompt, return_tensors='pt',
                    padding=False, max_length=8000
                )
                inputs = {key: value.to(model.device) for key, value in inputs.items()}

                res = model.generate(
                    **inputs,
                    use_cache=True
                )
                output = tokenizer.decode(res[0], skip_special_tokens=True)
                # Keep only the model's answer: drop everything up to and
                # including the last [/INST] marker.
                answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)

                generated_texts.append(answer)
                reference_texts.append(gt)

            # BUG FIX: utils.calc_metrics(answers, gts) expects the model
            # outputs first and the references second; the arguments were
            # previously passed in the opposite order, which swapped the
            # reference/candidate roles in the ROUGE computation.
            metrics = calc_metrics(generated_texts, reference_texts)

            # Ensure wandb is initialized
            if wandb.run is None:
                wandb.init()

            wandb.log(metrics, step=state.global_step)
            torch.cuda.empty_cache()
|
| 71 |
+
|
| 72 |
+
def main(args):
    """End-to-end LoRA fine-tuning entry point.

    Loads the base model and tokenizer, builds the train/test datasets,
    configures a HuggingFace Trainer (with DeepSpeed and a generation-based
    eval callback), trains, and saves the resulting adapter weights.
    """
    model_name = parse_model_name(args.base_model, args.from_remote)

    # Load Llama3 model (explicit repo id overrides parse_model_name's result).
    if args.base_model == 'llama3':
        model_name = 'meta-llama/Meta-Llama-3-8B'  # Replace with correct Llama3 model path or identifier
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Llama tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load data
    dataset_list = load_dataset(args.dataset, args.from_remote)
    dataset_train = datasets.concatenate_datasets([d['train'] for d in dataset_list]).shuffle(seed=42)

    if args.test_dataset:
        test_dataset_list = load_dataset(args.test_dataset, args.from_remote)
        dataset_test = datasets.concatenate_datasets([d['test'] for d in test_dataset_list])
    else:
        # BUG FIX: dataset_test was referenced below even when --test_dataset
        # was not supplied, raising NameError.  Fall back to the 'test' splits
        # of the training datasets (load_dataset creates an 80/20 split when
        # the dataset lacks one).
        dataset_test = datasets.concatenate_datasets([d['test'] for d in dataset_list])

    original_dataset = datasets.DatasetDict({'train': dataset_train, 'test': dataset_test})
    # ROBUSTNESS: cap the generation-eval subset at the test-set size so small
    # test sets no longer raise an index error.
    eval_dataset = original_dataset['test'].shuffle(seed=42).select(
        range(min(50, len(original_dataset['test'])))
    )

    # Tokenize, drop over-length examples, and strip raw-text columns the
    # collator cannot handle.
    dataset = original_dataset.map(partial(tokenize, args, tokenizer))
    dataset = dataset.filter(lambda x: not x['exceed_max_length'])
    dataset = dataset.remove_columns(
        ['prompt', 'answer', 'label', 'symbol', 'period', 'exceed_max_length']
    )

    current_time = datetime.now()
    formatted_time = current_time.strftime('%Y%m%d%H%M')

    training_args = TrainingArguments(
        output_dir=f'finetuned_models/{args.run_name}_{formatted_time}',  # Save location
        logging_steps=args.log_interval,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        dataloader_num_workers=args.num_workers,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        save_steps=args.eval_steps,
        eval_steps=args.eval_steps,
        fp16=True,
        deepspeed=args.ds_config,
        evaluation_strategy=args.evaluation_strategy,
        remove_unused_columns=False,
        report_to='wandb',
        run_name=args.run_name
    )

    # Gradient checkpointing trades compute for memory; use_cache must be off
    # during training when checkpointing is enabled.
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    model.is_parallelizable = True
    model.model_parallel = True
    model.config.use_cache = False

    # Setup PEFT with LoRA
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=lora_module_dict[args.base_model],
        bias='none',
    )
    model = get_peft_model(model, peft_config)

    # Train
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, padding=True,
            return_tensors="pt"
        ),
        callbacks=[
            GenerationEvalCallback(
                eval_dataset=eval_dataset,
                # Skip expensive generation eval for the first ~30% of epochs.
                ignore_until_epoch=round(0.3 * args.num_epochs)
            )
        ]
    )

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    torch.cuda.empty_cache()
    trainer.train()

    # Save model (adapter weights only, via PEFT)
    model.save_pretrained(training_args.output_dir)
|
| 174 |
+
|
| 175 |
+
if __name__ == "__main__":
    def _str2bool(value):
        """Parse a boolean CLI flag.

        BUG FIX: argparse `type=bool` treats any non-empty string (including
        "False") as True; parse the common truthy spellings explicitly.
        """
        return str(value).lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--run_name", default='local-test', type=str)
    parser.add_argument("--dataset", required=True, type=str)
    parser.add_argument("--test_dataset", type=str)
    parser.add_argument("--base_model", required=True, type=str, choices=['chatglm2', 'llama2', 'llama3', 'llama3.1'])
    parser.add_argument("--max_length", default=512, type=int)
    parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="The learning rate")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay")
    parser.add_argument("--num_epochs", default=8, type=float, help="The training epochs")
    parser.add_argument("--num_workers", default=8, type=int, help="dataloader workers")
    parser.add_argument("--log_interval", default=20, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=8, type=int)
    parser.add_argument("--warmup_ratio", default=0.05, type=float)
    parser.add_argument("--ds_config", default='./config_new.json', type=str)
    parser.add_argument("--scheduler", default='linear', type=str)
    parser.add_argument("--instruct_template", default='default')
    parser.add_argument("--evaluation_strategy", default='steps', type=str)
    parser.add_argument("--eval_steps", default=0.1, type=float)
    parser.add_argument("--from_remote", default=False, type=_str2bool)
    args = parser.parse_args()

    wandb.login()
    main(args)
|
fingpt/FinGPT_Forecaster/AAAI-Good-Data/utils.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
import os
import datasets
from sklearn.metrics import accuracy_score, mean_squared_error
from collections import defaultdict
from rouge_score import rouge_scorer

# LoRA target-module names per base model, including Llama3 support.
lora_module_dict = {
    'chatglm2': ['query_key_value'],
    'llama2': [
        'q_proj', 'k_proj', 'v_proj',
        'o_proj', 'gate_proj', 'up_proj', 'down_proj',
    ],
    'llama3': [  # modules adapted for Llama3-8b (same projection names as llama2)
        'q_proj', 'k_proj', 'v_proj',
        'o_proj', 'gate_proj', 'up_proj', 'down_proj',
    ],
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def tokenize(args, tokenizer, feature):
    """Encode one prompt/answer pair into model inputs.

    The prompt tokens are masked in the labels (replaced with the pad token id
    — NOTE(review): pad == eos here, not -100; confirm the collator handles
    this masking convention) so the loss covers only the answer portion.

    Returns:
        dict with 'input_ids', 'labels', and an 'exceed_max_length' flag that
        callers use to filter over-long examples.
    """
    encoded_prompt = tokenizer.encode(
        feature['prompt'].strip(), padding=False,
        max_length=args.max_length, truncation=True
    )

    encoded_answer = tokenizer.encode(
        feature['answer'].strip(), padding=False,
        max_length=args.max_length, truncation=True, add_special_tokens=False
    )

    full_ids = encoded_prompt + encoded_answer
    too_long = len(full_ids) >= args.max_length

    # Terminate the sequence with EOS unless it already ends with one or the
    # example is over-length (and will be filtered out anyway).
    if full_ids[-1] != tokenizer.eos_token_id and not too_long:
        full_ids = full_ids + [tokenizer.eos_token_id]

    prompt_mask = [tokenizer.pad_token_id] * len(encoded_prompt)

    return {
        "input_ids": full_ids,
        "labels": prompt_mask + full_ids[len(encoded_prompt):],
        "exceed_max_length": too_long
    }
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def parse_model_name(name, from_remote=False):
    """Map a short model alias to a HuggingFace repo id or local checkpoint path.

    Args:
        name: one of 'chatglm2', 'llama2', 'llama3'.
        from_remote: for chatglm2 only, choose the HF Hub id instead of the
            local copy under base_models/.

    Raises:
        ValueError: for any unknown alias.
    """
    if name == 'chatglm2':
        return 'THUDM/chatglm2-6b' if from_remote else 'base_models/chatglm2-6b'
    elif name == 'llama2':
        return 'meta-llama/Llama-2-7b-chat-hf'
    elif name == 'llama3':
        # CONSISTENCY FIX: the official repo id is 'meta-llama/Meta-Llama-3-8B'
        # (the id train_lora.py also uses); 'meta-llama/Llama-3-8B' does not
        # resolve on the Hub.
        return 'meta-llama/Meta-Llama-3-8B'
    else:
        raise ValueError(f"Undefined base model {name}")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_dataset(names, from_hf_hub=False):
    """Load one or more datasets, from local disk or the Hugging Face Hub.

    Args:
        names: dataset names/paths; multiple datasets are comma-separated.
        from_hf_hub: when True, fetch each name from the HF Hub; otherwise
            load it from disk with `datasets.load_from_disk`.

    Returns:
        A list of DatasetDicts.  Any dataset lacking a 'test' split is given
        an 80/20 train/test split (shuffled with a fixed seed).

    Raises:
        FileNotFoundError: if a local dataset path does not exist.
    """
    loaded = []

    for dataset_name in names.split(','):
        if from_hf_hub:
            ds = datasets.load_dataset(dataset_name)
        elif os.path.exists(dataset_name):
            ds = datasets.load_from_disk(dataset_name)
        else:
            raise FileNotFoundError(f"Dataset {dataset_name} not found in the specified path.")

        if 'test' not in ds:
            ds = ds.train_test_split(0.2, shuffle=True, seed=42)

        loaded.append(ds)

    return loaded
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def parse_answer(answer):
    """Parse a structured forecaster answer into its components.

    Expected template:
        [Positive Developments]: ...
        [Potential Concerns]: ...
        [Prediction & Analysis]: Prediction: ... Analysis: ...

    Returns:
        dict with the three text sections, a signed percentage point estimate
        ('prediction'), and its sign ('prediction_binary'); or None when the
        answer does not match the template.
    """
    match_res = re.match(
        r"^\s*\[Positive Developments\]:\s*(.*)\s*\[Potential Concerns\]:\s*(.*)\s*\[Prediction (&|and) Analysis\]:\s*(.*)\s*$",
        answer, flags=re.DOTALL)
    if not match_res:
        return None

    pros, cons, pna = match_res.group(1), match_res.group(2), match_res.group(4)

    match_res = re.match(r'^Prediction:\s*(.*)\s*Analysis:\s*(.*)\s*$', pna, flags=re.DOTALL)
    if not match_res:
        return None

    pred, anal = match_res.group(1), match_res.group(2)

    # Direction of the predicted move: +1 up, -1 down, 0 unclear.
    if re.search(r'up|increase', pred.lower()):
        pred_bin = 1
    elif re.search(r'down|decrease|decline', pred.lower()):
        pred_bin = -1
    else:
        pred_bin = 0

    # Magnitude: prefer a "lo-hi%" range, else a bare "N%".
    # BUG FIX: the old patterns used (\d) and (\d)+? which capture a single
    # digit, so "12%" was parsed as 2%.  (\d+) captures the whole number.
    match_res = re.search(r'(\d+)-(\d+)%', pred)
    if not match_res:
        match_res = re.search(r'(?:more than )?(\d+)%', pred)

    # Point estimate: lower bound + 0.5, signed by direction; 0.0 when no
    # percentage is mentioned.
    pred_margin = pred_bin * (int(match_res.group(1)) + 0.5) if match_res else 0.

    return {
        "positive developments": pros,
        "potential concerns": cons,
        "prediction": pred_margin,
        "prediction_binary": pred_bin,
        "analysis": anal
    }
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def calc_rouge_score(references, answers):
    """Average ROUGE-1/2/L F-measures over paired (reference, answer) texts.

    Both arguments must be non-empty sequences of equal length; raises
    ZeroDivisionError on empty input (callers guard against that).
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    pair_scores = [scorer.score(ref, ans) for ref, ans in zip(references, answers)]
    n_pairs = len(pair_scores)

    averages = {}
    for metric in ('rouge1', 'rouge2', 'rougeL'):
        averages[metric] = sum(s[metric].fmeasure for s in pair_scores) / n_pairs

    return averages
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def calc_metrics(answers, gts):
    """Score generated answers against ground-truth answers.

    Each side is parsed with parse_answer; pairs where either side fails to
    parse are skipped.  Prints and returns binary-direction accuracy, MSE of
    the signed percentage predictions, and ROUGE scores per text section.
    Returns an empty dict when nothing parsed successfully.
    """
    parsed_answers = defaultdict(list)
    parsed_gts = defaultdict(list)

    # Collect per-field lists for every pair that parses on both sides.
    for answer, gt in zip(answers, gts):
        ans_fields = parse_answer(answer)
        gt_fields = parse_answer(gt)

        if not (ans_fields and gt_fields):
            continue
        for key in ans_fields:
            parsed_answers[key].append(ans_fields[key])
            parsed_gts[key].append(gt_fields[key])

    if not parsed_answers['prediction']:
        return {}

    bin_acc = accuracy_score(parsed_gts['prediction_binary'], parsed_answers['prediction_binary'])
    mse = mean_squared_error(parsed_gts['prediction'], parsed_answers['prediction'])

    pros_rouge_scores = calc_rouge_score(parsed_gts['positive developments'], parsed_answers['positive developments'])
    cons_rouge_scores = calc_rouge_score(parsed_gts['potential concerns'], parsed_answers['potential concerns'])
    anal_rouge_scores = calc_rouge_score(parsed_gts['analysis'], parsed_answers['analysis'])

    print(f"\nBinary Accuracy: {bin_acc:.2f} | Mean Square Error: {mse:.2f}")
    print(f"\nRouge Score of Positive Developments: {pros_rouge_scores}")
    print(f"\nRouge Score of Potential Concerns: {cons_rouge_scores}")
    print(f"\nRouge Score of Summary Analysis: {anal_rouge_scores}")

    return {
        "valid_count": len(parsed_answers['prediction']),
        "bin_acc": bin_acc,
        "mse": mse,
        "pros_rouge_scores": pros_rouge_scores,
        "cons_rouge_scores": cons_rouge_scores,
        "anal_rouge_scores": anal_rouge_scores
    }
|