Skip to content

Add function calling and structured output for Novita AI #8023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/core_docs/docs/integrations/chat/novita.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | |"
"| ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | |"
]
},
{
Expand Down
56 changes: 29 additions & 27 deletions libs/langchain-community/src/chat_models/novita.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,8 @@ import {
} from "@langchain/openai";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

type NovitaUnsupportedArgs =
| "frequencyPenalty"
| "presencePenalty"
| "logitBias"
| "functions";

type NovitaUnsupportedCallOptions = "functions" | "function_call";

export interface ChatNovitaCallOptions
extends Omit<ChatOpenAICallOptions, NovitaUnsupportedCallOptions> {
response_format: {
type: "json_object";
schema: Record<string, unknown>;
};
}

export interface ChatNovitaInput
extends Omit<OpenAIChatInput, "openAIApiKey" | NovitaUnsupportedArgs>,
extends Omit<OpenAIChatInput, "openAIApiKey">,
BaseChatModelParams {
/**
* Novita API key
Expand All @@ -45,7 +29,7 @@ export interface ChatNovitaInput
/**
* Novita chat model implementation
*/
export class ChatNovitaAI extends ChatOpenAI<ChatNovitaCallOptions> {
export class ChatNovitaAI extends ChatOpenAI<ChatOpenAICallOptions> {
static lc_name() {
return "ChatNovita";
}
Expand All @@ -65,7 +49,7 @@ export class ChatNovitaAI extends ChatOpenAI<ChatNovitaCallOptions> {

constructor(
fields?: Partial<
Omit<OpenAIChatInput, "openAIApiKey" | NovitaUnsupportedArgs>
Omit<OpenAIChatInput, "openAIApiKey">
> &
BaseChatModelParams & {
novitaApiKey?: string;
Expand All @@ -85,7 +69,7 @@ export class ChatNovitaAI extends ChatOpenAI<ChatNovitaCallOptions> {

super({
...fields,
model: fields?.model || "gryphe/mythomax-l2-13b",
model: fields?.model || "qwen/qwen-2.5-72b-instruct",
apiKey: novitaApiKey,
configuration: {
baseURL: "https://api.novita.ai/v3/openai/",
Expand Down Expand Up @@ -133,15 +117,33 @@ export class ChatNovitaAI extends ChatOpenAI<ChatNovitaCallOptions> {
| AsyncIterable<OpenAIClient.Chat.Completions.ChatCompletionChunk>
| OpenAIClient.Chat.Completions.ChatCompletion
> {
delete request.frequency_penalty;
delete request.presence_penalty;
delete request.logit_bias;
delete request.functions;
if (request.response_format) {
if (request.response_format.type === "json_object") {
request.response_format = {
type: "json_object",
};
} else if ('json_schema' in request.response_format) {
const json_schema = request.response_format.json_schema;
request.response_format = {
type: "json_schema",
json_schema,
};
}
}

if (request.stream === true) {
return super.completionWithRetry(request, options);
if (!request.model) {
request.model = "qwen/qwen-2.5-72b-instruct";
}

return super.completionWithRetry(request, options);
try {
if (request.stream === true) {
return super.completionWithRetry(request, options);
}

return super.completionWithRetry(request, options);
} catch (error: any) {
console.error("Novita API call failed:", error.message || error);
throw error;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@ import {
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
} from "@langchain/core/prompts";
import { formatToOpenAITool } from "@langchain/openai";
import { StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import { ChatNovitaAI } from "../novita.js";

describe("ChatNovitaAI", () => {
test("invoke", async () => {
const chat = new ChatNovitaAI();
const message = new HumanMessage("Hello!");
const message = new HumanMessage("Hello! Who are you?");
const res = await chat.invoke([message]);
expect(res.content.length).toBeGreaterThan(10);
});

test("generate", async () => {
const chat = new ChatNovitaAI();
const message = new HumanMessage("Hello!");
const message = new HumanMessage("Hello! Who are you?");
const res = await chat.generate([[message]]);
expect(res.generations[0][0].text.length).toBeGreaterThan(10);
});
Expand Down Expand Up @@ -53,7 +56,6 @@ describe("ChatNovitaAI", () => {
test("prompt templates", async () => {
const chat = new ChatNovitaAI();

// PaLM doesn't support translation yet
const systemPrompt = PromptTemplate.fromTemplate(
"You are a helpful assistant who must always respond like a {job}."
);
Expand Down Expand Up @@ -88,4 +90,54 @@ describe("ChatNovitaAI", () => {
]);
expect(responseA.generations[0][0].text.length).toBeGreaterThan(10);
});
});

test("JSON mode", async () => {
  // Bind a JSON-object response format so the model is forced to emit valid JSON.
  const model = new ChatNovitaAI().bind({
    response_format: {
      type: "json_object"
    },
  });

  const template = ChatPromptTemplate.fromMessages([
    ["system", "You are a helpful assistant who responds in JSON. You must return a JSON object with an 'orderedArray' property containing the numbers in descending order."],
    ["human", "Please list this output in order of DESC [1, 4, 2, 8]."],
  ]);

  const result = await template.pipe(model).invoke({});

  // The raw content must be a JSON string that parses to the requested shape.
  expect(typeof result.content).toBe("string");
  const parsed = JSON.parse(result.content as string);
  expect(parsed).toMatchObject({
    orderedArray: expect.any(Array),
  });
});

test("Tool calls", async () => {
  // Minimal structured tool: validates {a, b} numbers and returns their sum as JSON.
  // Note: the explicit empty `constructor() { super(); }` in the original was
  // redundant (identical to the implicit default) and has been removed.
  class CalculatorTool extends StructuredTool {
    name = "Calculator";

    description = "A simple calculator tool.";

    schema = z.object({
      a: z.number(),
      b: z.number(),
    });

    async _call(input: { a: number; b: number }) {
      return JSON.stringify({ total: input.a + input.b });
    }
  }

  const tool = formatToOpenAITool(new CalculatorTool());
  const chat = new ChatNovitaAI().bind({
    tools: [tool],
  });
  const prompt = ChatPromptTemplate.fromMessages([
    ["system", "You are a helpful assistant."],
    ["human", "What is 1273926 times 27251?"],
  ]);
  const res = await prompt.pipe(chat).invoke({});
  // The model should decide to call the bound tool with numeric arguments.
  expect(res.tool_calls?.length).toBeGreaterThan(0);
  expect(res.tool_calls?.[0].args)
    .toMatchObject({ a: expect.any(Number), b: expect.any(Number) });
});
});