
Commit d4185eb

Add Writer model support to ChatBedrock (#478)

Fixes #472. Adds support for using [Writer Palmyra](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-writer-palmyra.html) models with the Invoke API via ChatBedrock. Note that the prompt format for the Writer Palmyra X4/X5 models is not publicly documented; the Llama/Mistral format appears to work without issues for now.
1 parent 73b8d02 commit d4185eb
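
As a quick orientation, the sketch below shows how the new provider path would be exercised from ChatBedrock. It is illustrative only: the `writer.palmyra-x5-v1:0` model ID is taken from the test fixtures in this commit, while the region, model availability, and `model_kwargs` values are assumptions.

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_aws import ChatBedrock

# Model ID taken from the test fixtures in this commit; region and
# availability are assumptions and may differ per account.
chat = ChatBedrock(
    model_id="writer.palmyra-x5-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.2, "max_tokens": 256},
)

response = chat.invoke(
    [
        SystemMessage(content="You are a concise assistant."),
        HumanMessage(content="Summarize what Amazon Bedrock does in one sentence."),
    ]
)
print(response.content)
```

With `beta_use_converse_api` left at its default, the request goes through the Invoke API code path that this commit extends.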

File tree

3 files changed: +97 −6


libs/aws/langchain_aws/chat_models/bedrock.py

24 additions, 0 deletions

```diff
@@ -226,6 +226,28 @@ def convert_messages_to_prompt_deepseek(messages: List[BaseMessage]) -> str:
     return prompt
 
 
+def _convert_one_message_to_text_writer(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+    elif isinstance(message, HumanMessage):
+        message_text = f"[INST] {message.content} [/INST]"
+    elif isinstance(message, AIMessage):
+        message_text = f"{message.content}"
+    elif isinstance(message, SystemMessage):
+        message_text = f"<<SYS>> {message.content} <</SYS>>"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_text
+
+
+def convert_messages_to_prompt_writer(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for Writer."""
+
+    return "\n".join(
+        [_convert_one_message_to_text_llama(message) for message in messages]
+    )
+
+
 def _format_image(image_url: str) -> Dict:
     """
     Formats an image of format data:image/jpeg;base64,{b64_string}
@@ -553,6 +575,8 @@ def convert_messages_to_prompt(
                 human_prompt="\n\nUser:",
                 ai_prompt="\n\nBot:",
             )
+        elif provider == "writer":
+            prompt = convert_messages_to_prompt_writer(messages=messages)
         else:
             raise NotImplementedError(
                 f"Provider {provider} model does not support chat."
```
libs/aws/langchain_aws/llms/bedrock.py

12 additions, 6 deletions

```diff
@@ -176,7 +176,7 @@ def _stream_response_to_generation_chunk(
     return GenerationChunk(
         text=(
             stream_response[output_key]
-            if provider not in ["mistral", "deepseek"]
+            if provider not in ["mistral", "deepseek", "writer"]
             else stream_response[output_key][0]["text"]
         ),
         generation_info=generation_info,
@@ -273,6 +273,7 @@ class LLMInputOutputAdapter:
         "deepseek": "choices",
         "meta": "generation",
         "mistral": "outputs",
+        "writer": "choices"
     }
 
     @classmethod
@@ -363,7 +364,7 @@ def prepare_input(
             if temperature is not None:
                 input_body["temperature"] = temperature
 
-        elif provider in ("ai21", "cohere", "meta", "mistral", "deepseek"):
+        elif provider in ("ai21", "cohere", "meta", "mistral", "deepseek", "writer"):
            input_body["prompt"] = prompt
            if max_tokens:
                if provider == "cohere":
@@ -374,6 +375,8 @@
                    input_body["max_tokens"] = max_tokens
                elif provider == "deepseek":
                    input_body["max_tokens"] = max_tokens
+                elif provider == "writer":
+                    input_body["max_tokens"] = max_tokens
                else:
                    # TODO: Add AI21 support, param depends on specific model.
                    pass
@@ -429,16 +432,16 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
             tool_calls = extract_tool_calls(content)
 
         else:
-            if provider == "ai21":
+            if provider in ["deepseek", "writer"]:
+                text = response_body.get("choices")[0].get("text")
+            elif provider == "ai21":
                 text = response_body.get("completions")[0].get("data").get("text")
             elif provider == "cohere":
                 text = response_body.get("generations")[0].get("text")
             elif provider == "meta":
                 text = response_body.get("generation")
             elif provider == "mistral":
                 text = response_body.get("outputs")[0].get("text")
-            elif provider == "deepseek":
-                text = response_body.get("choices")[0].get("text")
             else:
                 text = response_body.get("results")[0].get("outputText")
 
@@ -493,7 +496,10 @@ def prepare_output_stream(
 
             chunk_obj = json.loads(chunk.get("bytes").decode())
 
-            if provider == "cohere" and (
+            if provider == "writer" and chunk_obj == "[DONE]":
+                return
+
+            elif provider == "cohere" and (
                 chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return
```
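
For reference, a standalone sketch of the non-streaming response shape that the new `prepare_output` branch parses. The body is modeled on the unit-test fixture in the next file; any fields beyond `choices[0].text` are assumptions.

```python
import json

# Hypothetical Writer Palmyra Invoke response body, modeled on the unit-test
# fixture below; fields other than "choices"/"text" are assumptions.
raw_body = json.dumps(
    {
        "choices": [
            {"index": 0, "text": " This is the Writer output text.", "finish_reason": "stop"}
        ]
    }
).encode()

response_body = json.loads(raw_body)
# Same lookup the new `provider in ["deepseek", "writer"]` branch performs:
text = response_body.get("choices")[0].get("text")
print(text)  # -> " This is the Writer output text."
```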

libs/aws/tests/unit_tests/llms/test_bedrock.py

61 additions, 0 deletions

```diff
@@ -280,6 +280,21 @@ def test__human_assistant_format() -> None:
     {"chunk": {"bytes": b'{"choices": [{"text": "you.","stop_reason": "stop"}]}'}},
 ]
 
+MOCK_STREAMING_RESPONSE_WRITER = [
+    {"chunk": {'bytes': b'{"id":"cmpl-ec61121fa19443caa7f614bde08e926c",'
+                        b'"object":"text_completion",'
+                        b'"created":1747106231,'
+                        b'"model":"writer.palmyra-x5-v1:0",'
+                        b'"choices":[{"index":0,"text":"Hel","logprobs":null,"finish_reason":null,"stop_reason":null}],'
+                        b'"usage":null}'}},
+    {"chunk": {'bytes': b'{"id":"cmpl-ec61121fa19443caa7f614bde08e926c",'
+                        b'"object":"text_completion",'
+                        b'"created":1747106231,'
+                        b'"model":"writer.palmyra-x5-v1:0",'
+                        b'"choices":[{"index":0,"text":"lo.","logprobs":null,"finish_reason":"length","stop_reason":null}],'
+                        b'"usage":null}'}},
+    {"chunk": {'bytes': b'"[DONE]"'}},
+]
 
 
 async def async_gen_mock_streaming_response() -> AsyncGenerator[Dict, None]:
     for item in MOCK_STREAMING_RESPONSE:
@@ -372,6 +387,31 @@ def deepseek_streaming_response():
     return response
 
 
+@pytest.fixture
+def writer_response():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {'choices': [{'text': ' This is the Writer output text.'}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "17",
+                "x-amzn-bedrock-output-token-count": "8",
+            }
+        },
+    )
+
+    return response
+
+
+@pytest.fixture
+def writer_streaming_response():
+    response = dict(body=MOCK_STREAMING_RESPONSE_WRITER)
+    return response
+
+
 @pytest.fixture
 def cohere_response():
     body = MagicMock()
@@ -486,6 +526,27 @@ def test_prepare_output_stream_for_deepseek(deepseek_streaming_response) -> None
     assert results[1] == "you."
 
 
+def test_prepare_output_for_writer(writer_response):
+    result = LLMInputOutputAdapter.prepare_output("writer", writer_response)
+    assert result["text"] == " This is the Writer output text."
+    assert result["usage"]["prompt_tokens"] == 17
+    assert result["usage"]["completion_tokens"] == 8
+    assert result["usage"]["total_tokens"] == 25
+    assert result["stop_reason"] is None
+
+
+def test_prepare_output_stream_for_writer(writer_streaming_response) -> None:
+    results = [
+        chunk.text
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            "writer", writer_streaming_response
+        )
+    ]
+
+    assert results[0] == "Hel"
+    assert results[1] == "lo."
+
+
 def test_prepare_output_for_cohere(cohere_response):
     result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
     assert result["text"] == "This is the Cohere output text."
```
