diff --git a/chat.go b/chat.go index 995860c40..7112dc7b7 100644 --- a/chat.go +++ b/chat.go @@ -104,6 +104,12 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. @@ -119,41 +125,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { } if len(m.MultiContent) > 0 { msg := struct { - Role string `json:"role"` - Content string `json:"-"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"-"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &msg); err == nil { @@ -161,14 +170,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { return nil } multiMsg := struct { - Role string `json:"role"` - Content string - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &multiMsg); err != nil { return err diff --git a/chat_stream.go b/chat_stream.go index 525b4457a..80d16cc63 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct { FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` Refusal string `json:"refusal,omitempty"` + + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { diff --git a/chat_test.go b/chat_test.go index e90142da6..514706c96 100644 --- a/chat_test.go +++ b/chat_test.go @@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestDeepseekR1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "deepseek-reasoner", + MaxCompletionTokens: 100, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } +func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq openai.ChatCompletionRequest + if completionReq, err = getChatCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + n := completionReq.N + if n == 0 { + n = 1 + } + if completionReq.MaxCompletionTokens == 0 { + completionReq.MaxCompletionTokens = 1000 + } + for i := 0; i < n; i++ { + reasoningContent := "User says hello! And I need to reply" + completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent)) + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + ReasoningContent: reasoningContent, + Content: completionStr, + }, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = openai.Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + fmt.Fprintln(w, string(resBytes)) +} + // getChatCompletionBody Returns the body of the request to create a completion. func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { completion := openai.ChatCompletionRequest{} diff --git a/openai_test.go b/openai_test.go index 6c26eebd1..a55f3a858 100644 --- a/openai_test.go +++ b/openai_test.go @@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer/ +// https://beta.openai.com/tokenizer. // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int {