
Commit 306fbbb

Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925)
- support deepseek field "reasoning_content"
- Comment ends in a period (godot)
- add comment on field reasoning_content
- fix go lint error
- chore: trigger CI
- make field "content" in MarshalJSON function omitempty
- remove reasoning_content in TestO1ModelChatCompletions func
- feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.
1 parent 658beda commit 306fbbb

File tree

4 files changed: +128 -33 lines

- chat.go (+42 -32)
- chat_stream.go (+6 -0)
- chat_test.go (+79 -0)
- openai_test.go (+1 -1)

chat.go (+42 -32)

```diff
@@ -104,6 +104,12 @@ type ChatCompletionMessage struct {
 	// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	Name string `json:"name,omitempty"`
 
+	// This property is used for the "reasoning" feature supported by deepseek-reasoner
+	// which is not in the official documentation.
+	// the doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
+
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 
 	// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
@@ -119,56 +125,60 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 	}
 	if len(m.MultiContent) > 0 {
 		msg := struct {
-			Role         string            `json:"role"`
-			Content      string            `json:"-"`
-			Refusal      string            `json:"refusal,omitempty"`
-			MultiContent []ChatMessagePart `json:"content,omitempty"`
-			Name         string            `json:"name,omitempty"`
-			FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-			ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-			ToolCallID   string            `json:"tool_call_id,omitempty"`
+			Role             string            `json:"role"`
+			Content          string            `json:"-"`
+			Refusal          string            `json:"refusal,omitempty"`
+			MultiContent     []ChatMessagePart `json:"content,omitempty"`
+			Name             string            `json:"name,omitempty"`
+			ReasoningContent string            `json:"reasoning_content,omitempty"`
+			FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+			ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+			ToolCallID       string            `json:"tool_call_id,omitempty"`
 		}(m)
 		return json.Marshal(msg)
 	}
 
 	msg := struct {
-		Role         string            `json:"role"`
-		Content      string            `json:"content,omitempty"`
-		Refusal      string            `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"-"`
-		Name         string            `json:"name,omitempty"`
-		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-		ToolCallID   string            `json:"tool_call_id,omitempty"`
+		Role             string            `json:"role"`
+		Content          string            `json:"content,omitempty"`
+		Refusal          string            `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"-"`
+		Name             string            `json:"name,omitempty"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}(m)
 	return json.Marshal(msg)
 }
 
 func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 	msg := struct {
-		Role         string `json:"role"`
-		Content      string `json:"content,omitempty"`
-		Refusal      string `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart
-		Name         string        `json:"name,omitempty"`
-		FunctionCall *FunctionCall `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
-		ToolCallID   string        `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string `json:"content"`
+		Refusal          string `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart
+		Name             string        `json:"name,omitempty"`
+		ReasoningContent string        `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
+		ToolCallID       string        `json:"tool_call_id,omitempty"`
 	}{}
 
 	if err := json.Unmarshal(bs, &msg); err == nil {
 		*m = ChatCompletionMessage(msg)
 		return nil
 	}
 	multiMsg := struct {
-		Role         string `json:"role"`
-		Content      string
-		Refusal      string `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"content"`
-		Name         string        `json:"name,omitempty"`
-		FunctionCall *FunctionCall `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
-		ToolCallID   string        `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string
+		Refusal          string `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"content"`
+		Name             string        `json:"name,omitempty"`
+		ReasoningContent string        `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
+		ToolCallID       string        `json:"tool_call_id,omitempty"`
	}{}
 	if err := json.Unmarshal(bs, &multiMsg); err != nil {
 		return err
```
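Taken together, the marshal/unmarshal changes let `reasoning_content` survive a JSON round trip while still being omitted when empty. A minimal sketch of that behavior; the expected output in the comments is inferred from the struct tags in this diff, not from a captured run:

```go
package main

import (
	"encoding/json"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	msg := openai.ChatCompletionMessage{
		Role:             openai.ChatMessageRoleAssistant,
		ReasoningContent: "User says hello, so greet back.",
		Content:          "Hello!",
	}
	b, _ := json.Marshal(msg)
	fmt.Println(string(b))
	// {"role":"assistant","content":"Hello!","reasoning_content":"User says hello, so greet back."}

	// Unmarshal restores the field from the wire format.
	var back openai.ChatCompletionMessage
	_ = json.Unmarshal(b, &back)
	fmt.Println(back.ReasoningContent) // User says hello, so greet back.

	// With no content at all, omitempty keeps the payload minimal.
	b, _ = json.Marshal(openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser})
	fmt.Println(string(b)) // {"role":"user"}
}
```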

chat_stream.go (+6 -0)

```diff
@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 	ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
 	Refusal      string        `json:"refusal,omitempty"`
+
+	// This property is used for the "reasoning" feature supported by deepseek-reasoner
+	// which is not in the official documentation.
+	// the doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
 }
 
 type ChatCompletionStreamChoiceLogprobs struct {
```
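With the delta field in place, a streaming consumer can drain reasoning tokens and answer tokens separately. A sketch under assumed conditions (a client already pointed at DeepSeek's OpenAI-compatible endpoint; per the DeepSeek docs linked above, reasoning_content deltas arrive before content deltas):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"

	openai "github.com/sashabaranov/go-openai"
)

func streamReasoner(client *openai.Client) error {
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:  "deepseek-reasoner",
		Stream: true,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		return err
	}
	defer stream.Close()

	for {
		chunk, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil // stream finished
		}
		if err != nil {
			return err
		}
		delta := chunk.Choices[0].Delta
		if delta.ReasoningContent != "" {
			fmt.Print(delta.ReasoningContent) // reasoning tokens
		} else {
			fmt.Print(delta.Content) // answer tokens
		}
	}
}
```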

chat_test.go (+79 -0)

```diff
@@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+func TestDeepseekR1ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:               "deepseek-reasoner",
+		MaxCompletionTokens: 100,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
@@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	fmt.Fprintln(w, string(resBytes))
 }
 
+func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
+	var err error
+	var resBytes []byte
+
+	// completions only accepts POST requests
+	if r.Method != "POST" {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+	}
+	var completionReq openai.ChatCompletionRequest
+	if completionReq, err = getChatCompletionBody(r); err != nil {
+		http.Error(w, "could not read request", http.StatusInternalServerError)
+		return
+	}
+	res := openai.ChatCompletionResponse{
+		ID:      strconv.Itoa(int(time.Now().Unix())),
+		Object:  "test-object",
+		Created: time.Now().Unix(),
+		// would be nice to validate Model during testing, but
+		// this may not be possible with how much upkeep
+		// would be required / wouldn't make much sense
+		Model: completionReq.Model,
+	}
+	// create completions
+	n := completionReq.N
+	if n == 0 {
+		n = 1
+	}
+	if completionReq.MaxCompletionTokens == 0 {
+		completionReq.MaxCompletionTokens = 1000
+	}
+	for i := 0; i < n; i++ {
+		reasoningContent := "User says hello! And I need to reply"
+		completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
+		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
+			Message: openai.ChatCompletionMessage{
+				Role:             openai.ChatMessageRoleAssistant,
+				ReasoningContent: reasoningContent,
+				Content:          completionStr,
+			},
+			Index: i,
+		})
+	}
+	inputTokens := numTokens(completionReq.Messages[0].Content) * n
+	completionTokens := completionReq.MaxTokens * n
+	res.Usage = openai.Usage{
+		PromptTokens:     inputTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      inputTokens + completionTokens,
+	}
+	resBytes, _ = json.Marshal(res)
+	w.Header().Set(xCustomHeader, xCustomHeaderValue)
+	for k, v := range rateLimitHeaders {
+		switch val := v.(type) {
+		case int:
+			w.Header().Set(k, strconv.Itoa(val))
+		default:
+			w.Header().Set(k, fmt.Sprintf("%s", v))
+		}
+	}
+	fmt.Fprintln(w, string(resBytes))
+}
+
 // getChatCompletionBody Returns the body of the request to create a completion.
 func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
 	completion := openai.ChatCompletionRequest{}
```
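The mocked handler above mirrors what a real call would return. For reference, the same request against DeepSeek's live API might look like the sketch below; the base URL and API key are placeholder assumptions and not part of this commit, while the deepseek-reasoner model name and ReasoningContent field come from the diff:

```go
package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	cfg := openai.DefaultConfig("YOUR_DEEPSEEK_API_KEY") // placeholder key
	cfg.BaseURL = "https://api.deepseek.com/v1"          // assumed OpenAI-compatible endpoint
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-reasoner",
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		panic(err)
	}

	msg := resp.Choices[0].Message
	fmt.Println("reasoning_content:", msg.ReasoningContent) // chain of thought
	fmt.Println("content:", msg.Content)                    // final answer
}
```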

openai_test.go (+1 -1)

```diff
@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
 
 // numTokens Returns the number of GPT-3 encoded tokens in the given text.
 // This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer/
+// https://beta.openai.com/tokenizer.
 //
 // TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
 func numTokens(s string) int {
```
