
Commit 716117c

fix PromptTemplate for Llama3
1 parent 5f1db32 commit 716117c

File tree: 1 file changed (+30, -12)


packages/cdk/lambda/utils/models.ts

@@ -67,7 +67,7 @@ const TITAN_TEXT_PROMPT: PromptTemplate = {
   eosToken: '',
 };

-const LLAMA_PROMPT: PromptTemplate = {
+const LLAMA2_PROMPT: PromptTemplate = {
   prefix: '<s>[INST] ',
   suffix: ' [/INST]',
   join: '',
@@ -77,6 +77,16 @@ const LLAMA_PROMPT: PromptTemplate = {
   eosToken: '</s>',
 };

+const LLAMA3_PROMPT: PromptTemplate = {
+  prefix: '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n',
+  suffix: ' [/INST]',
+  join: '',
+  user: '{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n',
+  assistant: '{}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n',
+  system: '{}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n',
+  eosToken: '',
+};
+
 const MISTRAL_PROMPT: PromptTemplate = {
   prefix: '<s>[INST] ',
   suffix: ' [/INST]',
@@ -202,9 +212,17 @@ const createBodyTextTitanText = (messages: UnrecordedMessage[]) => {
   return JSON.stringify(body);
 };

-const createBodyTextLlama = (messages: UnrecordedMessage[]) => {
+const createBodyTextLlama2 = (messages: UnrecordedMessage[]) => {
+  const body: LlamaParams = {
+    prompt: generatePrompt(LLAMA2_PROMPT, messages),
+    ...LLAMA_DEFAULT_PARAMS,
+  };
+  return JSON.stringify(body);
+};
+
+const createBodyTextLlama3 = (messages: UnrecordedMessage[]) => {
   const body: LlamaParams = {
-    prompt: generatePrompt(LLAMA_PROMPT, messages),
+    prompt: generatePrompt(LLAMA3_PROMPT, messages),
     ...LLAMA_DEFAULT_PARAMS,
   };
   return JSON.stringify(body);
@@ -354,23 +372,23 @@ export const BEDROCK_MODELS: {
     extractOutputText: extractOutputTextTitanText,
   },
   'meta.llama3-8b-instruct-v1:0': {
-    promptTemplate: LLAMA_PROMPT,
-    createBodyText: createBodyTextLlama,
+    promptTemplate: LLAMA3_PROMPT,
+    createBodyText: createBodyTextLlama3,
     extractOutputText: extractOutputTextLlama,
   },
   'meta.llama3-70b-instruct-v1:0': {
-    promptTemplate: LLAMA_PROMPT,
-    createBodyText: createBodyTextLlama,
+    promptTemplate: LLAMA3_PROMPT,
+    createBodyText: createBodyTextLlama3,
     extractOutputText: extractOutputTextLlama,
   },
   'meta.llama2-13b-chat-v1': {
-    promptTemplate: LLAMA_PROMPT,
-    createBodyText: createBodyTextLlama,
+    promptTemplate: LLAMA2_PROMPT,
+    createBodyText: createBodyTextLlama2,
     extractOutputText: extractOutputTextLlama,
   },
   'meta.llama2-70b-chat-v1': {
-    promptTemplate: LLAMA_PROMPT,
-    createBodyText: createBodyTextLlama,
+    promptTemplate: LLAMA2_PROMPT,
+    createBodyText: createBodyTextLlama2,
     extractOutputText: extractOutputTextLlama,
   },
   'mistral.mistral-7b-instruct-v0:2': {
@@ -412,7 +430,7 @@ export const BEDROCK_IMAGE_GEN_MODELS: {

 export const getSageMakerModelTemplate = (model: string): PromptTemplate => {
   if (model.includes('llama-2')) {
-    return LLAMA_PROMPT;
+    return LLAMA2_PROMPT;
   } else if (model.includes('bilingual-rinna')) {
     return BILINGUAL_RINNA_PROMPT;
   } else if (model.includes('rinna')) {
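
As a rough illustration of how the new LLAMA3_PROMPT fields might be consumed, here is a minimal TypeScript sketch. The expandPrompt helper is hypothetical: the real generatePrompt implementation is not part of this diff, and the sketch only assumes that each message's content replaces the '{}' slot of the template string for its role, with the pieces joined by `join` and wrapped in `prefix`/`suffix`.

// Hypothetical sketch: generatePrompt is not shown in this commit, so
// expandPrompt below is only an assumption about how the template is applied.
type Role = 'system' | 'user' | 'assistant';

interface UnrecordedMessage {
  role: Role;
  content: string;
}

interface PromptTemplate {
  prefix: string;
  suffix: string;
  join: string;
  user: string;
  assistant: string;
  system: string;
  eosToken: string;
}

// Template values copied verbatim from the LLAMA3_PROMPT added in this commit.
const LLAMA3_PROMPT: PromptTemplate = {
  prefix: '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n',
  suffix: ' [/INST]',
  join: '',
  user: '{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n',
  assistant: '{}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n',
  system: '{}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n',
  eosToken: '',
};

// Assumed expansion: substitute each message into the '{}' slot of its role's
// template string, join the results, and wrap everything in prefix/suffix.
const expandPrompt = (
  template: PromptTemplate,
  messages: UnrecordedMessage[]
): string =>
  template.prefix +
  messages
    .map((m) => template[m.role].replace('{}', m.content))
    .join(template.join) +
  template.suffix;

// Example conversation: a system instruction followed by one user turn.
const prompt = expandPrompt(LLAMA3_PROMPT, [
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Hello!' },
]);

// createBodyTextLlama3 in the diff additionally spreads LLAMA_DEFAULT_PARAMS
// (defined elsewhere in models.ts and not shown here) into the body.
const body = JSON.stringify({ prompt });

console.log(prompt);
console.log(body);

Under these assumptions, the serialized body is what the meta.llama3-* entries registered in BEDROCK_MODELS would send to Bedrock, while the llama2 entries keep the old [INST]-style template via createBodyTextLlama2.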
