Skip to content

Commit 68b9dae

Browse files
authored
Feat: mcp server (#7084)
### What problem does this PR solve? Add MCP support with a client example. Issue link: #4344 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
1 parent 9b956ac commit 68b9dae

File tree

7 files changed

+403
-3
lines changed

7 files changed

+403
-3
lines changed

Dockerfile

+1
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ COPY agent agent
198198
COPY graphrag graphrag
199199
COPY agentic_reasoning agentic_reasoning
200200
COPY pyproject.toml uv.lock ./
201+
COPY mcp mcp
201202

202203
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
203204
COPY docker/entrypoint.sh ./

docker/docker-compose.yml

+8
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,21 @@ services:
88
mysql:
99
condition: service_healthy
1010
image: ${RAGFLOW_IMAGE}
11+
# Example: how to set up the MCP server
12+
# command:
13+
# - --enable-mcpserver
14+
# - --mcp-host=0.0.0.0
15+
# - --mcp-port=9382
16+
# - --mcp-base-url=http://127.0.0.1:9380
17+
# - --mcp-script-path=/ragflow/mcp/server/server.py
1118
container_name: ragflow-server
1219
ports:
1320
- ${SVR_HTTP_PORT}:9380
1421
- 80:80
1522
- 443:443
1623
- 5678:5678
1724
- 5679:5679
25+
- 9382:9382 # Entry for MCP (host_port:docker_port). The docker_port must match the value you set for `mcp-port` above.
1826
volumes:
1927
- ./ragflow-logs:/ragflow/logs
2028
- ./nginx/ragflow.conf:/etc/nginx/conf.d/ragflow.conf

docker/entrypoint.sh

+41
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ function usage() {
1010
echo
1111
echo " --disable-webserver Disables the web server (nginx + ragflow_server)."
1212
echo " --disable-taskexecutor Disables task executor workers."
13+
echo " --enable-mcpserver Enables the MCP server."
1314
echo " --consumer-no-beg=<num> Start range for consumers (if using range-based)."
1415
echo " --consumer-no-end=<num> End range for consumers (if using range-based)."
1516
echo " --workers=<num> Number of task executors to run (if range is not used)."
@@ -19,15 +20,22 @@ function usage() {
1920
echo " $0 --disable-taskexecutor"
2021
echo " $0 --disable-webserver --consumer-no-beg=0 --consumer-no-end=5"
2122
echo " $0 --disable-webserver --workers=2 --host-id=myhost123"
23+
echo " $0 --enable-mcpserver"
2224
exit 1
2325
}
2426

2527
ENABLE_WEBSERVER=1 # Default to enable web server
2628
ENABLE_TASKEXECUTOR=1 # Default to enable task executor
29+
ENABLE_MCP_SERVER=0
2730
CONSUMER_NO_BEG=0
2831
CONSUMER_NO_END=0
2932
WORKERS=1
3033

34+
MCP_HOST="127.0.0.1"
35+
MCP_PORT=9382
36+
MCP_BASE_URL="http://127.0.0.1:9380"
37+
MCP_SCRIPT_PATH="/ragflow/mcp/server/server.py"
38+
3139
# -----------------------------------------------------------------------------
3240
# Host ID logic:
3341
# 1. By default, use the system hostname if length <= 32
@@ -53,6 +61,26 @@ for arg in "$@"; do
5361
ENABLE_TASKEXECUTOR=0
5462
shift
5563
;;
64+
--enable-mcpserver)
65+
ENABLE_MCP_SERVER=1
66+
shift
67+
;;
68+
--mcp-host=*)
69+
MCP_HOST="${arg#*=}"
70+
shift
71+
;;
72+
--mcp-port=*)
73+
MCP_PORT="${arg#*=}"
74+
shift
75+
;;
76+
--mcp-base-url=*)
77+
MCP_BASE_URL="${arg#*=}"
78+
shift
79+
;;
80+
--mcp-script-path=*)
81+
MCP_SCRIPT_PATH="${arg#*=}"
82+
shift
83+
;;
5684
--consumer-no-beg=*)
5785
CONSUMER_NO_BEG="${arg#*=}"
5886
shift
@@ -105,6 +133,14 @@ function task_exe() {
105133
done
106134
}
107135

136+
function start_mcp_server() {
    # Launch the MCP server script in the background with the interpreter in $PY.
    # Host, port and base URL come from the MCP_* variables set by the
    # command-line parsing earlier in this script.
    local -a mcp_args=(
        "--host=${MCP_HOST}"
        "--port=${MCP_PORT}"
        "--base_url=${MCP_BASE_URL}"
    )

    echo "Starting MCP Server on ${MCP_HOST}:${MCP_PORT} with base URL ${MCP_BASE_URL}..."
    "$PY" "${MCP_SCRIPT_PATH}" "${mcp_args[@]}" &
}
143+
108144
# -----------------------------------------------------------------------------
109145
# Start components based on flags
110146
# -----------------------------------------------------------------------------
@@ -119,6 +155,11 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then
119155
done &
120156
fi
121157

158+
159+
if [[ "${ENABLE_MCP_SERVER}" -eq 1 ]]; then
160+
start_mcp_server
161+
fi
162+
122163
if [[ "${ENABLE_TASKEXECUTOR}" -eq 1 ]]; then
123164
if [[ "${CONSUMER_NO_END}" -gt "${CONSUMER_NO_BEG}" ]]; then
124165
echo "Starting task executors on host '${HOST_ID}' for IDs in [${CONSUMER_NO_BEG}, ${CONSUMER_NO_END})..."

mcp/client/client.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
18+
from mcp.client.session import ClientSession
19+
from mcp.client.sse import sse_client
20+
21+
22+
async def main(
    url="http://localhost:9382/sse",
    api_key="ragflow-IyMGI1ZDhjMTA2ZTExZjBiYTMyMGQ4Zm",
    dataset_ids=("ce3bb17cf27a11efa69751e139332ced",),
    question="How to install neovim?",
):
    """Connect to the MCP server over SSE, list its tools and run one retrieval.

    All parameters default to the values this example was originally
    hard-coded with, so calling ``main()`` with no arguments behaves as before.

    Args:
        url: SSE endpoint of the MCP server.
        api_key: RAGFlow API key, sent in the ``api_key`` request header.
        dataset_ids: IDs of the datasets to search.
        question: Query text passed to the retrieval tool.
    """
    try:
        async with sse_client(url, headers={"api_key": api_key}) as streams:
            async with ClientSession(
                streams[0],
                streams[1],
            ) as session:
                await session.initialize()
                tools = await session.list_tools()
                print(f"{tools.tools=}")
                response = await session.call_tool(
                    name="ragflow_retrival",
                    arguments={"dataset_ids": list(dataset_ids), "document_ids": [], "question": question},
                )
                print(f"Tool response: {response.model_dump()}")

    except Exception as e:
        # Example script: report the failure and exit normally instead of
        # propagating a traceback.
        print(e)
37+
38+
39+
if __name__ == "__main__":
40+
from anyio import run
41+
42+
run(main)

mcp/server/server.py

+234
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
#
2+
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import json
18+
from collections.abc import AsyncIterator
19+
from contextlib import asynccontextmanager
20+
21+
import requests
22+
from starlette.applications import Starlette
23+
from starlette.middleware import Middleware
24+
from starlette.middleware.base import BaseHTTPMiddleware
25+
from starlette.responses import JSONResponse
26+
from starlette.routing import Mount, Route
27+
28+
import mcp.types as types
29+
from mcp.server.lowlevel import Server
30+
from mcp.server.sse import SseServerTransport
31+
32+
BASE_URL = "http://127.0.0.1:9380"
33+
HOST = "127.0.0.1"
34+
PORT = "9382"
35+
36+
37+
class RAGFlowConnector:
38+
def __init__(self, base_url: str, version="v1"):
39+
self.base_url = base_url
40+
self.version = version
41+
self.api_url = f"{self.base_url}/api/{self.version}"
42+
43+
def bind_api_key(self, api_key: str):
44+
self.api_key = api_key
45+
self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)}
46+
47+
def _post(self, path, json=None, stream=False, files=None):
48+
if not self.api_key:
49+
return None
50+
res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
51+
return res
52+
53+
def _get(self, path, params=None, json=None):
54+
res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
55+
return res
56+
57+
def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None):
58+
res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
59+
if not res:
60+
raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
61+
62+
res = res.json()
63+
if res.get("code") == 0:
64+
result_list = []
65+
for data in res["data"]:
66+
d = {"description": data["description"], "id": data["id"]}
67+
result_list.append(json.dumps(d, ensure_ascii=False))
68+
return "\n".join(result_list)
69+
return ""
70+
71+
def retrival(
72+
self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False
73+
):
74+
if document_ids is None:
75+
document_ids = []
76+
data_json = {
77+
"page": page,
78+
"page_size": page_size,
79+
"similarity_threshold": similarity_threshold,
80+
"vector_similarity_weight": vector_similarity_weight,
81+
"top_k": top_k,
82+
"rerank_id": rerank_id,
83+
"keyword": keyword,
84+
"question": question,
85+
"dataset_ids": dataset_ids,
86+
"document_ids": document_ids,
87+
}
88+
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
89+
res = self._post("/retrieval", json=data_json)
90+
if not res:
91+
raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))])
92+
93+
res = res.json()
94+
if res.get("code") == 0:
95+
chunks = []
96+
for chunk_data in res["data"].get("chunks"):
97+
chunks.append(json.dumps(chunk_data, ensure_ascii=False))
98+
return [types.TextContent(type="text", text="\n".join(chunks))]
99+
raise Exception([types.TextContent(type="text", text=res.get("message"))])
100+
101+
102+
class RAGFlowCtx:
    """Context object that exposes a RAGFlow connector as ``conn``."""

    def __init__(self, connector: RAGFlowConnector):
        # ``conn`` is the only attribute this class sets.
        self.conn = connector
105+
106+
107+
@asynccontextmanager
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
    """Yield the lifespan context: a dict with one ``"ragflow_ctx"`` entry.

    The entry is a ``RAGFlowCtx`` wrapping a connector built from the
    module-level ``BASE_URL``. There is nothing to clean up on exit.
    """
    connector = RAGFlowConnector(base_url=BASE_URL)
    yield {"ragflow_ctx": RAGFlowCtx(connector)}
115+
116+
117+
app = Server("ragflow-server", lifespan=server_lifespan)
118+
sse = SseServerTransport("/messages/")
119+
120+
121+
@app.list_tools()
async def list_tools() -> list[types.Tool]:
    """Advertise the retrieval tool, with the dataset listing in its description.

    The API key is read from the request headers captured in the session's
    initialization options and bound to the connector before datasets are
    listed.

    Raises:
        ValueError: if the lifespan context or the ``api_key`` header is missing.
    """
    ctx = app.request_context
    ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
    if not ragflow_ctx:
        raise ValueError("Get RAGFlow Context failed")
    connector = ragflow_ctx.conn

    # .get() so a missing header reaches the ValueError below instead of
    # surfacing as a bare KeyError.
    api_key = ctx.session._init_options.capabilities.experimental["headers"].get("api_key")
    if not api_key:
        raise ValueError("RAGFlow API_KEY is required.")
    connector.bind_api_key(api_key)

    dataset_description = connector.list_datasets()

    return [
        types.Tool(
            # Name and argument keys are the ones call_tool() dispatches on and reads.
            name="ragflow_retrival",
            description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question, using the specified dataset_ids and optionally document_ids. Below is the list of all available datasets, including their descriptions and IDs. If you're unsure which datasets are relevant to the question, simply pass all dataset IDs to the function."
            # Separate the prose from the dataset listing.
            + "\n"
            + dataset_description,
            inputSchema={
                "type": "object",
                "properties": {"dataset_ids": {"type": "array", "items": {"type": "string"}}, "document_ids": {"type": "array", "items": {"type": "string"}}, "question": {"type": "string"}},
                "required": ["dataset_ids", "question"],
            },
        ),
    ]
148+
149+
150+
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    """Dispatch a tool call to the RAGFlow connector.

    Args:
        name: Tool name requested by the client.
        arguments: Tool arguments; ``dataset_ids`` and ``question`` are
            required, the document-id list is optional.

    Raises:
        ValueError: if the lifespan context or the ``api_key`` header is
            missing, or if ``name`` is not a known tool.
    """
    ctx = app.request_context
    ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
    if not ragflow_ctx:
        raise ValueError("Get RAGFlow Context failed")
    connector = ragflow_ctx.conn

    # .get() so a missing header reaches the ValueError below instead of
    # surfacing as a bare KeyError.
    api_key = ctx.session._init_options.capabilities.experimental["headers"].get("api_key")
    if not api_key:
        raise ValueError("RAGFlow API_KEY is required.")
    connector.bind_api_key(api_key)

    # Accept both the prefixed and the bare tool name, and both spellings of
    # the optional document-id argument, so clients built against either
    # spelling keep working.
    if name in ("ragflow_retrival", "retrival"):
        document_ids = arguments.get("document_ids") or arguments.get("documents_ids") or []
        return connector.retrival(dataset_ids=arguments["dataset_ids"], document_ids=document_ids, question=arguments["question"])
    raise ValueError(f"Tool not found: {name}")
166+
167+
168+
async def handle_sse(request):
    """Serve one SSE connection by running the MCP server over its streams.

    The request headers are passed to the session as the experimental
    capability ``"headers"`` so the tool handlers can read the API key.

    Returns:
        An empty Starlette ``Response`` once the SSE session has ended.
    """
    # Local import: the file's top-level starlette imports do not include Response.
    from starlette.responses import Response

    async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
        await app.run(streams[0], streams[1], app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)}))
    # A Route endpoint must return a response; returning None makes Starlette
    # fail with "'NoneType' object is not callable" when the client disconnects.
    return Response()
171+
172+
173+
class AuthMiddleware(BaseHTTPMiddleware):
    """Reject MCP requests that carry no ``api_key`` header.

    Only paths starting with ``/sse`` or ``/messages`` are checked; every other
    request passes through. The key is only checked for presence here.
    """

    async def dispatch(self, request, call_next):
        """Return 401 for an MCP path without ``api_key``, else continue the chain."""
        if request.url.path.startswith(("/sse", "/messages")):
            if not request.headers.get("api_key"):
                return JSONResponse({"error": "Missing authorization header"}, status_code=401)
        return await call_next(request)
180+
181+
182+
starlette_app = Starlette(
183+
debug=True,
184+
routes=[
185+
Route("/sse", endpoint=handle_sse),
186+
Mount("/messages/", app=sse.handle_post_message),
187+
],
188+
middleware=[Middleware(AuthMiddleware)],
189+
)
190+
191+
192+
if __name__ == "__main__":
193+
"""
194+
Launch example:
195+
uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380
196+
"""
197+
198+
import argparse
199+
import os
200+
201+
import uvicorn
202+
from dotenv import load_dotenv
203+
204+
load_dotenv()
205+
206+
parser = argparse.ArgumentParser(description="RAGFlow MCP Server, `base_url` and `api_key` are needed.")
207+
parser.add_argument("--base_url", type=str, default="http://127.0.0.1:9380", help="api_url: http://<host_address>")
208+
parser.add_argument("--host", type=str, default="127.0.0.1", help="RAGFlow MCP SERVER host")
209+
parser.add_argument("--port", type=str, default="9382", help="RAGFlow MCP SERVER port")
210+
args = parser.parse_args()
211+
212+
BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", args.base_url)
213+
HOST = os.environ.get("RAGFLOW_MCP_HOST", args.host)
214+
PORT = os.environ.get("RAGFLOW_MCP_PORT", args.port)
215+
216+
print(
217+
r"""
218+
__ __ ____ ____ ____ _____ ______ _______ ____
219+
| \/ |/ ___| _ \ / ___|| ____| _ \ \ / / ____| _ \
220+
| |\/| | | | |_) | \___ \| _| | |_) \ \ / /| _| | |_) |
221+
| | | | |___| __/ ___) | |___| _ < \ V / | |___| _ <
222+
|_| |_|\____|_| |____/|_____|_| \_\ \_/ |_____|_| \_\
223+
""",
224+
flush=True,
225+
)
226+
print(f"MCP host: {HOST}", flush=True)
227+
print(f"MCP port: {PORT}", flush=True)
228+
print(f"MCP base_url: {BASE_URL}", flush=True)
229+
230+
uvicorn.run(
231+
starlette_app,
232+
host=HOST,
233+
port=int(PORT),
234+
)

0 commit comments

Comments
 (0)