Skip to content

Commit 6eecb8f

Browse files
authored
Improve error handling (#10)
* chore: update the exponential backoff error handling * chore: update versions of langgraph and langchain * chore: update poetry lock file * fix: remove the unused functions from util
1 parent ec5f316 commit 6eecb8f

10 files changed

+278
-228
lines changed

bedrock_deep_research.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from bedrock_deep_research import BedrockDeepResearch
1414
from bedrock_deep_research.config import DEFAULT_TOPIC, SUPPORTED_MODELS, Configuration
1515
from bedrock_deep_research.model import Section
16+
from bedrock_deep_research.utils import CustomError
1617

1718
logger = logging.getLogger(__name__)
1819
LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
@@ -146,20 +147,28 @@ def render_initial_form():
146147
with st.spinner(
147148
"Please wait while the article outline is being generated..."
148149
):
149-
response = st.session_state.bedrock_deep_research.start(
150-
topic)
151-
152-
logger.debug(f"Outline response: {response}")
153-
154-
state = st.session_state.bedrock_deep_research.get_state()
155-
156-
article = Article(
157-
title=state.values["title"],
158-
sections=state.values["sections"],
159-
)
160-
st.session_state.article = article.render_outline()
161-
st.session_state.stage = "outline_feedback"
162-
st.rerun()
150+
try:
151+
response = st.session_state.bedrock_deep_research.start(
152+
topic)
153+
except CustomError as e:
154+
logger.error(f"Bedrock ClientError: {e}")
155+
raise e
156+
157+
except Exception as e:
158+
logger.error(
159+
f"An error occurred while creating the outline: {e}")
160+
raise e
161+
else:
162+
logger.debug(f"Outline response: {response}")
163+
state = st.session_state.bedrock_deep_research.get_state()
164+
165+
article = Article(
166+
title=state.values["title"],
167+
sections=state.values["sections"],
168+
)
169+
st.session_state.article = article.render_outline()
170+
st.session_state.stage = "outline_feedback"
171+
st.rerun()
163172
except Exception as e:
164173
logger.error(f"An error occurred: {e}")
165174
raise
@@ -288,7 +297,7 @@ def on_accept_outline_button_click():
288297
st.session_state.text_error = ""
289298
st.rerun()
290299
except Exception as e:
291-
st.error(f"An error occurred: {e}")
300+
st.error(f"An error occurred in article creation: {e}")
292301

293302
except Exception as e:
294303
logger.error(f"An error occurred: {e}")

bedrock_deep_research/nodes/article_head_image_generator.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, message):
2828
self.message = message
2929

3030

31-
@exponential_backoff_retry(Exception, max_retries=10)
31+
@exponential_backoff_retry(ClientError, max_retries=10)
3232
def generate_image(model_id, body):
3333
"""
3434
Generate an image using Amazon Nova Canvas model on demand.
@@ -44,7 +44,6 @@ def generate_image(model_id, body):
4444

4545
accept = "application/json"
4646
content_type = "application/json"
47-
4847
response = bedrock.invoke_model(
4948
body=body, modelId=model_id, accept=accept, contentType=content_type
5049
)
@@ -143,7 +142,6 @@ def __call__(self, state: ArticleState, config: RunnableConfig):
143142
logger.error("A bedrock client error occurred:", message)
144143
except Exception as e:
145144
logger.error(
146-
147145
"An error occurred during ArticleHeadImageGenerator:", e)
148146

149147
logger.info("Generated head image: %s", image_path)

bedrock_deep_research/nodes/article_outline_generator.py

-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def __call__(self, state: ArticleState, config: RunnableConfig):
6767
outline = self.generate_outline(
6868
configurable.planner_model, configurable.max_tokens, system_prompt, user_prompt)
6969

70-
7170
logger.info(f"Generated sections: {outline.sections}")
7271
sections = [
7372
Section(section_number=i, name=section.name,

bedrock_deep_research/nodes/final_sections_writer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from botocore.exceptions import ClientError
12
from langchain_aws import ChatBedrock
23
from langchain_core.messages import HumanMessage, SystemMessage
34
from langchain_core.runnables import RunnableConfig
@@ -87,7 +88,7 @@ def __call__(self, state: SectionState, config: RunnableConfig):
8788

8889
return {"completed_sections": [section]}
8990

90-
@exponential_backoff_retry(Exception, max_retries=10)
91+
@exponential_backoff_retry(ClientError, max_retries=10)
9192
def _generate_final_sections(
9293
self,
9394
model: ChatBedrock,

bedrock_deep_research/nodes/initial_researcher.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from typing import List
44

5+
from botocore.exceptions import ClientError
56
from langchain_aws import ChatBedrock
67
from langchain_core.messages import HumanMessage, SystemMessage
78
from langchain_core.runnables import RunnableConfig
@@ -61,7 +62,7 @@ def __call__(self, state: ArticleInputState, config: RunnableConfig):
6162

6263
return {"source_str": source_str}
6364

64-
@exponential_backoff_retry(Exception, max_retries=10)
65+
@exponential_backoff_retry(ClientError, max_retries=10)
6566
def generate_search_queries(self, model_id: str, max_tokens: int, system_prompt: str, user_prompt: str) -> List[str]:
6667
planner_model = ChatBedrock(
6768
model_id=model_id, max_tokens=max_tokens)

bedrock_deep_research/nodes/section_search_query_generator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
from botocore.client import ClientError
34
from langchain_aws import ChatBedrock
45
from langchain_core.messages import HumanMessage, SystemMessage
56
from langchain_core.runnables import RunnableConfig
@@ -54,7 +55,7 @@ def __call__(self, state: SectionState, config: RunnableConfig):
5455
return {"search_queries": queries.queries}
5556

5657

57-
@exponential_backoff_retry(Exception, max_retries=10)
58+
@exponential_backoff_retry(ClientError, max_retries=10)
5859
def generate_section_queries(configurable: Configuration, section: Section) -> Queries:
5960
planner_model = ChatBedrock(
6061
model_id=configurable.planner_model, max_tokens=configurable.max_tokens

bedrock_deep_research/nodes/section_writer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from typing import List, Literal
33

4+
from botocore.exceptions import ClientError
45
from langchain_aws import ChatBedrock
56
from langchain_core.messages import HumanMessage, SystemMessage
67
from langchain_core.runnables import RunnableConfig
@@ -155,7 +156,7 @@ def __call__(self, state: SectionState, config: RunnableConfig) -> Command[Liter
155156
goto=SectionWebResearcher.N,
156157
)
157158

158-
@exponential_backoff_retry(Exception, max_retries=10)
159+
@exponential_backoff_retry(ClientError, max_retries=10)
159160
def _generate_section_content(
160161
self,
161162
model: ChatBedrock,
@@ -183,7 +184,7 @@ def _generate_section_content(
183184

184185
return section_content.content
185186

186-
@exponential_backoff_retry(Exception, max_retries=10)
187+
@exponential_backoff_retry(ClientError, max_retries=10)
187188
def _grade_section_content(
188189
self, model: ChatBedrock, system_prompt: str, section: Section
189190
) -> Feedback:

bedrock_deep_research/utils.py

+50-18
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
11
import logging
22
import random
3+
import re
34
import time
45
from functools import wraps
56

7+
from botocore.exceptions import ClientError
8+
69
logger = logging.getLogger(__name__)
710

811

12+
class CustomError(Exception):
    """Application-level error carrying a human-readable message.

    Raised by the retry/backoff helpers to surface AWS client failures
    (e.g. expired credentials, throttling limits reached) to callers.
    """

    def __init__(self, message):
        # Forward to Exception so `args`, repr(), and pickling behave
        # like a normal exception; the original skipped this, leaving
        # e.args empty.
        super().__init__(message)
        self.message = message

    def __str__(self):
        # Keep the exact original string form: the bare message.
        return self.message
18+
19+
920
def exponential_backoff_retry(
1021
ExceptionToCheck, max_retries: int = 5, initial_delay: float = 1.0
1122
):
@@ -14,6 +25,7 @@ def exponential_backoff_retry(
1425
1526
Args:
1627
func: Function to retry
28+
ExceptionToCheck: Exception class to check for retrying
1729
max_retries: Maximum number of retry attempts
1830
initial_delay: Initial delay in seconds before first retry
1931
"""
@@ -26,26 +38,37 @@ def wrapper(*args, **kwargs):
2638
for attempt in range(max_retries + 1):
2739
try:
2840
return func(*args, **kwargs)
29-
30-
except ExceptionToCheck as e:
31-
32-
if attempt == max_retries:
41+
except ClientError as e:
42+
if e.response['Error']['Code'] == 'ExpiredTokenException':
3343
logger.error(
34-
f"Execution failed after {attempt} attempts")
44+
"Expired token error. Please check/update your Security Token included in the request")
45+
# Do not try max_retry times
46+
attempt = max_retries
47+
raise CustomError(
48+
message="Expired Token. Please update the AWS credentials, to connect to the boto Client.")
49+
elif e.response['Error']['Code'] == 'ThrottlingException':
50+
if attempt == max_retries:
51+
logger.error(
52+
f"Error code: {e.response['Error']['Code']}"
53+
f"Execution failed after {max_retries} attempts due to throttling. Try again later.")
54+
raise CustomError(
55+
message=f"Throttling Exception raised.. Retry limit of {max_retries} retries reached.")
56+
logger.info(
57+
f"Attempt {attempt+1} failed due to throttling. Retrying...")
58+
# Add jitter to avoid thundering herd problem
59+
jitter = random.uniform(0, 0.1 * delay)
60+
sleep_time = delay + jitter
61+
logger.debug(
62+
f"Retrying in {sleep_time:.2f} seconds..."
63+
)
64+
time.sleep(sleep_time)
65+
delay *= 2 # Exponential backoff
66+
else:
67+
logger.error(f"Client Error Raised: {e}")
3568
raise e
36-
37-
# Add jitter to avoid thundering herd problem
38-
jitter = random.uniform(0, 0.1 * delay)
39-
sleep_time = delay + jitter
40-
41-
logger.debug(
42-
f"Attempt {attempt + 1}/{max_retries} failed. {str(e)}"
43-
f"Retrying in {sleep_time:.2f} seconds..."
44-
)
45-
46-
time.sleep(sleep_time)
47-
delay *= 2 # Exponential backoff
48-
69+
except ExceptionToCheck as e:
70+
logger.error(f"Error raised by {func.__name__}: {e}")
71+
raise e
4972
return wrapper
5073

5174
return decorator
@@ -74,3 +97,12 @@ def format_web_search(search_response, max_tokens_per_source, include_raw_conten
7497
formatted_text += f"Full source content limited to {max_tokens_per_source} tokens: {raw_content}\n\n"
7598

7699
return formatted_text.strip()
100+
101+
102+
def extract_xml_content(text: str, tag_name: str) -> str | None:
103+
pattern = f"<{tag_name}>(.*?)</{tag_name}>"
104+
match = re.search(pattern, text, re.DOTALL)
105+
if match:
106+
return match.group(1).strip()
107+
else:
108+
return None

0 commit comments

Comments
 (0)