Skip to content

Commit 6d7c738

Browse files
kumare3pingsutw
andauthored
Webhook tasks using FlyteAgents (#3058)
Signed-off-by: Ketan Umare <[email protected]> Signed-off-by: Kevin Su <[email protected]> Co-authored-by: Ketan Umare <[email protected]> Co-authored-by: Kevin Su <[email protected]>
1 parent 9acab29 commit 6d7c738

File tree

15 files changed

+614
-86
lines changed

15 files changed

+614
-86
lines changed

dev-requirements.in

+1
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@ ipykernel
6060

6161
orjson
6262
kubernetes>=12.0.1
63+
httpx

flytekit/clis/sdk_in_container/serve.py

+2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def agent(_: click.Context, port, prometheus_port, worker, timeout, modules):
6868
"""
6969
import asyncio
7070

71+
from flytekit.extras.webhook import WebhookTask # noqa: F401
72+
7173
working_dir = os.getcwd()
7274
if all(os.path.realpath(path) != working_dir for path in sys.path):
7375
sys.path.append(working_dir)

flytekit/extend/backend/base_agent.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from flytekit.models.literals import LiteralMap
3535
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate
3636

37+
# It's used to force agent to run in the same event loop in the local execution.
38+
local_agent_loop = asyncio.new_event_loop()
39+
3740

3841
class TaskCategory:
3942
def __init__(self, name: str, version: int = 0):
@@ -285,7 +288,7 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap:
285288
output_prefix = ctx.file_access.get_random_remote_directory()
286289

287290
agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version)
288-
resource = asyncio.run(
291+
resource = local_agent_loop.run_until_complete(
289292
self._do(agent=agent, template=task_template, output_prefix=output_prefix, inputs=kwargs)
290293
)
291294
if resource.phase != TaskExecution.SUCCEEDED:
@@ -335,10 +338,10 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap:
335338
task_template = get_serializable(OrderedDict(), ss, self).template
336339
self._agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version)
337340

338-
resource_meta = asyncio.run(
341+
resource_meta = local_agent_loop.run_until_complete(
339342
self._create(task_template=task_template, output_prefix=output_prefix, inputs=kwargs)
340343
)
341-
resource = asyncio.run(self._get(resource_meta=resource_meta))
344+
resource = local_agent_loop.run_until_complete(self._get(resource_meta=resource_meta))
342345

343346
if resource.phase != TaskExecution.SUCCEEDED:
344347
raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}")

flytekit/extras/webhook/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .agent import WebhookAgent
2+
from .task import WebhookTask
3+
4+
__all__ = ["WebhookTask", "WebhookAgent"]

flytekit/extras/webhook/agent.py

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from typing import Optional
2+
3+
import httpx
4+
from flyteidl.core.execution_pb2 import TaskExecution
5+
6+
from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase
7+
from flytekit.interaction.string_literals import literal_map_string_repr
8+
from flytekit.models.literals import LiteralMap
9+
from flytekit.models.task import TaskTemplate
10+
from flytekit.utils.dict_formatter import format_dict
11+
12+
from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, TIMEOUT_SEC, URL_KEY
13+
14+
15+
class WebhookAgent(SyncAgentBase):
16+
"""
17+
WebhookAgent is responsible for handling webhook tasks.
18+
19+
This agent sends HTTP requests based on the task template and inputs provided,
20+
and processes the responses to determine the success or failure of the task.
21+
22+
:param client: An optional HTTP client to use for sending requests.
23+
"""
24+
25+
name: str = "Webhook Agent"
26+
27+
def __init__(self, client: Optional[httpx.AsyncClient] = None):
28+
super().__init__(task_type_name=TASK_TYPE)
29+
self._client = client or httpx.AsyncClient()
30+
31+
async def do(
32+
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
33+
) -> Resource:
34+
"""
35+
This method processes the webhook task and sends an HTTP request.
36+
37+
It uses asyncio to send the request and process the response using the httpx library.
38+
"""
39+
try:
40+
final_dict = self._get_final_dict(task_template, inputs)
41+
return await self._process_webhook(final_dict)
42+
except Exception as e:
43+
return Resource(phase=TaskExecution.FAILED, message=str(e))
44+
45+
def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> dict:
46+
custom_dict = task_template.custom
47+
input_dict = {
48+
"inputs": literal_map_string_repr(inputs),
49+
}
50+
return format_dict("test", custom_dict, input_dict)
51+
52+
async def _make_http_request(self, method: str, url: str, headers: dict, data: dict, timeout: int) -> tuple:
53+
if method == "GET":
54+
response = await self._client.get(url, headers=headers, params=data, timeout=timeout)
55+
else:
56+
response = await self._client.post(url, json=data, headers=headers, timeout=timeout)
57+
return response.status_code, response.text
58+
59+
@staticmethod
60+
def _build_response(
61+
status: int,
62+
text: str,
63+
data: dict = None,
64+
url: str = None,
65+
show_data: bool = False,
66+
show_url: bool = False,
67+
) -> dict:
68+
final_response = {
69+
"status_code": status,
70+
"response_data": text,
71+
}
72+
if show_data:
73+
final_response["input_data"] = data
74+
if show_url:
75+
final_response["url"] = url
76+
return final_response
77+
78+
async def _process_webhook(self, final_dict: dict) -> Resource:
79+
url = final_dict.get(URL_KEY)
80+
body = final_dict.get(DATA_KEY)
81+
headers = final_dict.get(HEADERS_KEY)
82+
method = str(final_dict.get(METHOD_KEY)).upper()
83+
show_data = final_dict.get(SHOW_DATA_KEY, False)
84+
show_url = final_dict.get(SHOW_URL_KEY, False)
85+
timeout_sec = final_dict.get(TIMEOUT_SEC, 10)
86+
87+
status, text = await self._make_http_request(method, url, headers, body, timeout_sec)
88+
if status != 200:
89+
return Resource(
90+
phase=TaskExecution.FAILED,
91+
message=f"Webhook failed with status code {status}, response: {text}",
92+
)
93+
final_response = self._build_response(status, text, body, url, show_data, show_url)
94+
return Resource(
95+
phase=TaskExecution.SUCCEEDED,
96+
outputs={"info": final_response},
97+
message="Webhook was successfully invoked!",
98+
)
99+
100+
101+
AgentRegistry.register(WebhookAgent())

flytekit/extras/webhook/constants.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
TASK_TYPE: str = "webhook"
2+
3+
URL_KEY: str = "url"
4+
METHOD_KEY: str = "method"
5+
HEADERS_KEY: str = "headers"
6+
DATA_KEY: str = "data"
7+
SHOW_DATA_KEY: str = "show_data"
8+
SHOW_URL_KEY: str = "show_url"
9+
TIMEOUT_SEC: str = "timeout_sec"

flytekit/extras/webhook/task.py

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from datetime import timedelta
2+
from typing import Any, Dict, Optional, Type, Union
3+
4+
from flytekit import Documentation
5+
from flytekit.configuration import SerializationSettings
6+
from flytekit.core.base_task import PythonTask
7+
from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin
8+
9+
from ...core.interface import Interface
10+
from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, TIMEOUT_SEC, URL_KEY
11+
12+
13+
class WebhookTask(SyncAgentExecutorMixin, PythonTask):
14+
"""
15+
The WebhookTask is used to invoke a webhook. The webhook can be invoked with a POST or GET method.
16+
17+
All the parameters can be formatted using python format strings.
18+
19+
Example:
20+
```python
21+
simple_get = WebhookTask(
22+
name="simple-get",
23+
url="http://localhost:8000/",
24+
method=http.HTTPMethod.GET,
25+
headers={"Content-Type": "application/json"},
26+
)
27+
28+
get_with_params = WebhookTask(
29+
name="get-with-params",
30+
url="http://localhost:8000/items/{inputs.item_id}",
31+
method=http.HTTPMethod.GET,
32+
headers={"Content-Type": "application/json"},
33+
dynamic_inputs={"s": str, "item_id": int},
34+
show_data=True,
35+
show_url=True,
36+
description="Test Webhook Task",
37+
data={"q": "{inputs.s}"},
38+
)
39+
40+
41+
@fk.workflow
42+
def wf(s: str) -> (dict, dict, dict):
43+
v = hello(s=s)
44+
w = WebhookTask(
45+
name="invoke-slack",
46+
url="https://hooks.slack.com/services/xyz/zaa/aaa",
47+
headers={"Content-Type": "application/json"},
48+
data={"text": "{inputs.s}"},
49+
show_data=True,
50+
show_url=True,
51+
description="Test Webhook Task",
52+
dynamic_inputs={"s": str},
53+
)
54+
return simple_get(), get_with_params(s=v, item_id=10), w(s=v)
55+
```
56+
57+
All the parameters can be formatted using python format strings. The following parameters are available for
58+
formatting:
59+
- dynamic_inputs: These are the dynamic inputs to the task. The keys are the names of the inputs and the values
60+
are the values of the inputs. All inputs are available under the prefix `inputs.`.
61+
For example, if the inputs are {"input1": 10, "input2": "hello"}, then you can
62+
use {inputs.input1} and {inputs.input2} in the URL and the body. Define the dynamic_inputs argument in the
63+
constructor to use these inputs. The dynamic inputs should not be actual values, but the types of the inputs.
64+
65+
TODO Coming soon secrets support
66+
- secrets: These are the secrets that are requested by the task. The keys are the names of the secrets and the
67+
values are the values of the secrets. All secrets are available under the prefix `secrets.`.
68+
For example, if the secret requested are Secret(name="secret1") and Secret(name="secret), then you can use
69+
{secrets.secret1} and {secrets.secret2} in the URL and the body. Define the secret_requests argument in the
70+
constructor to use these secrets. The secrets should not be actual values, but the types of the secrets.
71+
72+
:param name: Name of this task, should be unique in the project
73+
:param url: The endpoint or URL to invoke for this webhook. This can be a static string or a python format string,
74+
where the format arguments are the dynamic_inputs to the task, secrets etc. Refer to the description for more
75+
details of available formatting parameters.
76+
:param method: The HTTP method to use for the request. Default is POST.
77+
:param headers: The headers to send with the request. This can be a static dictionary or a python format string,
78+
where the format arguments are the dynamic_inputs to the task, secrets etc. Refer to the description for more
79+
details of available formatting parameters.
80+
:param data: The body to send with the request. This can be a static dictionary or a python format string,
81+
where the format arguments are the dynamic_inputs to the task, secrets etc. Refer to the description for more
82+
details of available formatting parameters. the data should be a json serializable dictionary and will be
83+
sent as the json body of the POST request and as the query parameters of the GET request.
84+
:param dynamic_inputs: The dynamic inputs to the task. The keys are the names of the inputs and the values
85+
are the types of the inputs. These inputs are available under the prefix `inputs.` to be used in the URL,
86+
headers and body and other formatted fields.
87+
:param secret_requests: The secrets that are requested by the task. (TODO not yet supported)
88+
:param show_data: If True, the body of the request will be logged in the UI as the output of the task.
89+
:param show_url: If True, the URL of the request will be logged in the UI as the output of the task.
90+
:param description: Description of the task
91+
:param timeout: The timeout for the request (connection and read). Default is 10 seconds. If int value is provided,
92+
it is considered as seconds.
93+
"""
94+
95+
def __init__(
96+
self,
97+
name: str,
98+
url: str,
99+
method: str = "POST",
100+
headers: Optional[Dict[str, str]] = None,
101+
data: Optional[Dict[str, Any]] = None,
102+
dynamic_inputs: Optional[Dict[str, Type]] = None,
103+
show_data: bool = False,
104+
show_url: bool = False,
105+
description: Optional[str] = None,
106+
timeout: Union[int, timedelta] = timedelta(seconds=10),
107+
# secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon
108+
):
109+
if method not in {"GET", "POST"}:
110+
raise ValueError(f"Method should be either GET or POST. Got {method}")
111+
112+
interface = Interface(
113+
inputs=dynamic_inputs or {},
114+
outputs={"info": dict},
115+
)
116+
super().__init__(
117+
name=name,
118+
interface=interface,
119+
task_type=TASK_TYPE,
120+
# secret_requests=secret_requests,
121+
docs=Documentation(short_description=description) if description else None,
122+
)
123+
self._url = url
124+
self._method = method
125+
self._headers = headers
126+
self._data = data
127+
self._show_data = show_data
128+
self._show_url = show_url
129+
self._timeout_sec = timeout if isinstance(timeout, int) else timeout.total_seconds()
130+
131+
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
132+
config = {
133+
URL_KEY: self._url,
134+
METHOD_KEY: self._method,
135+
HEADERS_KEY: self._headers or {},
136+
DATA_KEY: self._data or {},
137+
SHOW_DATA_KEY: self._show_data,
138+
SHOW_URL_KEY: self._show_url,
139+
TIMEOUT_SEC: self._timeout_sec,
140+
}
141+
return config

flytekit/utils/dict_formatter.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import re
2+
from typing import Any, Dict, Optional
3+
4+
5+
def get_nested_value(d: Dict[str, Any], keys: list[str]) -> Any:
6+
"""
7+
Retrieve the nested value from a dictionary based on a list of keys.
8+
"""
9+
for key in keys:
10+
if key not in d:
11+
raise ValueError(f"Could not find the key {key} in {d}.")
12+
d = d[key]
13+
return d
14+
15+
16+
def replace_placeholder(
17+
service: str,
18+
original_dict: str,
19+
placeholder: str,
20+
replacement: str,
21+
) -> str:
22+
"""
23+
Replace a placeholder in the original string and handle the specific logic for the sagemaker service and idempotence token.
24+
"""
25+
temp_dict = original_dict.replace(f"{{{placeholder}}}", replacement)
26+
if service == "sagemaker" and placeholder in [
27+
"inputs.idempotence_token",
28+
"idempotence_token",
29+
]:
30+
if len(temp_dict) > 63:
31+
truncated_token = replacement[: 63 - len(original_dict.replace(f"{{{placeholder}}}", ""))]
32+
return original_dict.replace(f"{{{placeholder}}}", truncated_token)
33+
else:
34+
return temp_dict
35+
return temp_dict
36+
37+
38+
def format_dict(
39+
service: str,
40+
original_dict: Any,
41+
update_dict: Dict[str, Any],
42+
idempotence_token: Optional[str] = None,
43+
) -> Any:
44+
"""
45+
Recursively update a dictionary with format strings with values from another dictionary where the keys match
46+
the format string. This goes a little beyond regular python string formatting and uses `.` to denote nested keys.
47+
48+
For example, if original_dict is {"EndpointConfigName": "{endpoint_config_name}"},
49+
and update_dict is {"endpoint_config_name": "my-endpoint-config"},
50+
then the result will be {"EndpointConfigName": "my-endpoint-config"}.
51+
52+
For nested keys if the original_dict is {"EndpointConfigName": "{inputs.endpoint_config_name}"},
53+
and update_dict is {"inputs": {"endpoint_config_name": "my-endpoint-config"}},
54+
then the result will be {"EndpointConfigName": "my-endpoint-config"}.
55+
56+
:param service: The AWS service to use
57+
:param original_dict: The dictionary to update (in place)
58+
:param update_dict: The dictionary to use for updating
59+
:param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic
60+
:return: The updated dictionary
61+
"""
62+
if original_dict is None:
63+
return None
64+
65+
if isinstance(original_dict, str) and "{" in original_dict and "}" in original_dict:
66+
matches = re.findall(r"\{([^}]+)\}", original_dict)
67+
for match in matches:
68+
if "." in match:
69+
keys = match.split(".")
70+
nested_value = get_nested_value(update_dict, keys)
71+
if f"{{{match}}}" == original_dict:
72+
return nested_value
73+
else:
74+
original_dict = replace_placeholder(service, original_dict, match, str(nested_value))
75+
elif match == "idempotence_token" and idempotence_token:
76+
original_dict = replace_placeholder(service, original_dict, match, idempotence_token)
77+
return original_dict
78+
79+
if isinstance(original_dict, list):
80+
return [format_dict(service, item, update_dict, idempotence_token) for item in original_dict]
81+
82+
if isinstance(original_dict, dict):
83+
for key, value in original_dict.items():
84+
original_dict[key] = format_dict(service, value, update_dict, idempotence_token)
85+
86+
return original_dict

0 commit comments

Comments
 (0)