|
| 1 | +import os |
| 2 | +from typing import List |
| 3 | +import requests |
| 4 | +import json |
| 5 | +from dotenv import load_dotenv |
| 6 | + |
| 7 | +from langchain_groq import ChatGroq |
| 8 | +from langgraph.prebuilt import create_react_agent |
| 9 | +from langchain_core.tools import tool |
| 10 | +from langchain_core.messages import HumanMessage |
| 11 | +from langgraph.checkpoint.memory import MemorySaver |
| 12 | + |
| 13 | +load_dotenv() |
| 14 | +SECTORS_API_KEY = os.getenv("SECTORS_API_KEY") |
| 15 | + |
| 16 | +llm = ChatGroq(model="llama3-8b-8192") |
| 17 | + |
| 18 | +def retrieve_from_endpoint(url: str) -> dict: |
| 19 | + headers = {"Authorization": SECTORS_API_KEY} |
| 20 | + |
| 21 | + try: |
| 22 | + response = requests.get(url, headers=headers) |
| 23 | + response.raise_for_status() |
| 24 | + data = response.json() |
| 25 | + except requests.exceptions.HTTPError as err: |
| 26 | + raise SystemExit(err) |
| 27 | + return json.dumps(data) |
| 28 | + |
| 29 | +@tool |
| 30 | +def get_company_overview(stock: str) -> str: |
| 31 | + """ |
| 32 | + Get company overview, enter stock code (e.g BBRI, TLKM) |
| 33 | + """ |
| 34 | + |
| 35 | + url = f"https://api.sectors.app/v1/company/report/{stock}/?sections=overview" |
| 36 | + |
| 37 | + return retrieve_from_endpoint(url) |
| 38 | + |
| 39 | +@tool |
| 40 | +def get_sector_overview(sector: str) -> str: |
| 41 | + """ |
| 42 | + Get sector overview, enter sector name (e.g banks, housing estate development) |
| 43 | + """ |
| 44 | + |
| 45 | + url = f"https://api.sectors.app/v1/subsector/report/{sector}/" |
| 46 | + |
| 47 | + return retrieve_from_endpoint(url) |
| 48 | + |
| 49 | +def get_all_valid_subsector_slugs() -> str: |
| 50 | + """ |
| 51 | + Get all valid subsector slugs |
| 52 | + """ |
| 53 | + |
| 54 | + url = "https://api.sectors.app/v1/subsectors/" |
| 55 | + |
| 56 | + return retrieve_from_endpoint(url) |
| 57 | + |
| 58 | +def match_input_to_valid_subsector_slug( |
| 59 | + valid_subsector_slugs: List[str], |
| 60 | + user_input: str, |
| 61 | + fuzzy_threshold: int = 80, |
| 62 | + ) -> str: |
| 63 | + """ |
| 64 | + Match input to valid subsector slug |
| 65 | + """ |
| 66 | + from fuzzywuzzy import fuzz |
| 67 | + |
| 68 | + good_approximation = [] |
| 69 | + # Challenge (1) implement this here |
| 70 | + |
| 71 | + return good_approximation |
| 72 | + |
| 73 | + |
| 74 | +tools = [ |
| 75 | + get_company_overview, |
| 76 | + get_sector_overview, |
| 77 | + # ... other tools |
| 78 | +] |
| 79 | + |
| 80 | +memory = MemorySaver() |
| 81 | +system_message = "You are an expert tool calling agent meant for financial data retriever and summarization. Use tools to get the information you need, be descriptive, insightful and use the data you get to make high quality commentary." |
| 82 | + |
| 83 | +app = create_react_agent(llm, |
| 84 | + tools, |
| 85 | + state_modifier=system_message, |
| 86 | + checkpointer=memory |
| 87 | +) |
| 88 | + |
| 89 | + |
| 90 | +def chat(session_id: str, input: str) -> str: |
| 91 | + out = app.invoke( |
| 92 | + { |
| 93 | + "messages": [ |
| 94 | + HumanMessage( |
| 95 | + content=input, |
| 96 | + session_id=session_id, |
| 97 | + ) |
| 98 | + ] |
| 99 | + }, |
| 100 | + config={"configurable": {"thread_id": "supertype"}}, |
| 101 | + ) |
| 102 | + return f'🤖: {out["messages"][-1].content}' |
| 103 | + |
| 104 | +if __name__ == "__main__": |
| 105 | + |
| 106 | + valid_subsector_slugs = get_all_valid_subsector_slugs() |
| 107 | + subsector_slugs = [item.get('subsector') for item in eval(valid_subsector_slugs)] |
| 108 | + |
| 109 | + user_input = input("→: Enter a sector name (e.g. 'banks') or 4-digit ticker (e.g 'bmri'). Enter .q to exit. \n→: ") |
| 110 | + # Challenge (2) implement the fuzzy search here |
| 111 | + |
| 112 | + out = chat("supertype", f"Give me a company overview of {user_input}") |
| 113 | + print(out) |
| 114 | + |
| 115 | + |
| 116 | + |
| 117 | + |
0 commit comments