|
| 1 | +# |
| 2 | +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# |
| 16 | + |
| 17 | +import json |
| 18 | +from collections.abc import AsyncIterator |
| 19 | +from contextlib import asynccontextmanager |
| 20 | + |
| 21 | +import requests |
| 22 | +from starlette.applications import Starlette |
| 23 | +from starlette.middleware import Middleware |
| 24 | +from starlette.middleware.base import BaseHTTPMiddleware |
| 25 | +from starlette.responses import JSONResponse |
| 26 | +from starlette.routing import Mount, Route |
| 27 | + |
| 28 | +import mcp.types as types |
| 29 | +from mcp.server.lowlevel import Server |
| 30 | +from mcp.server.sse import SseServerTransport |
| 31 | + |
# Default endpoints; all three are overridden in the __main__ block below
# by environment variables (RAGFLOW_MCP_*) or CLI flags.
BASE_URL = "http://127.0.0.1:9380"  # RAGFlow backend API base URL
HOST = "127.0.0.1"                  # host this MCP server binds to
PORT = "9382"                       # port this MCP server binds to (kept as str; cast at uvicorn.run)
| 35 | + |
| 36 | + |
| 37 | +class RAGFlowConnector: |
| 38 | + def __init__(self, base_url: str, version="v1"): |
| 39 | + self.base_url = base_url |
| 40 | + self.version = version |
| 41 | + self.api_url = f"{self.base_url}/api/{self.version}" |
| 42 | + |
| 43 | + def bind_api_key(self, api_key: str): |
| 44 | + self.api_key = api_key |
| 45 | + self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.api_key)} |
| 46 | + |
| 47 | + def _post(self, path, json=None, stream=False, files=None): |
| 48 | + if not self.api_key: |
| 49 | + return None |
| 50 | + res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files) |
| 51 | + return res |
| 52 | + |
| 53 | + def _get(self, path, params=None, json=None): |
| 54 | + res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json) |
| 55 | + return res |
| 56 | + |
| 57 | + def list_datasets(self, page: int = 1, page_size: int = 1000, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None): |
| 58 | + res = self._get("/datasets", {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name}) |
| 59 | + if not res: |
| 60 | + raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))]) |
| 61 | + |
| 62 | + res = res.json() |
| 63 | + if res.get("code") == 0: |
| 64 | + result_list = [] |
| 65 | + for data in res["data"]: |
| 66 | + d = {"description": data["description"], "id": data["id"]} |
| 67 | + result_list.append(json.dumps(d, ensure_ascii=False)) |
| 68 | + return "\n".join(result_list) |
| 69 | + return "" |
| 70 | + |
| 71 | + def retrival( |
| 72 | + self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword: bool = False |
| 73 | + ): |
| 74 | + if document_ids is None: |
| 75 | + document_ids = [] |
| 76 | + data_json = { |
| 77 | + "page": page, |
| 78 | + "page_size": page_size, |
| 79 | + "similarity_threshold": similarity_threshold, |
| 80 | + "vector_similarity_weight": vector_similarity_weight, |
| 81 | + "top_k": top_k, |
| 82 | + "rerank_id": rerank_id, |
| 83 | + "keyword": keyword, |
| 84 | + "question": question, |
| 85 | + "dataset_ids": dataset_ids, |
| 86 | + "document_ids": document_ids, |
| 87 | + } |
| 88 | + # Send a POST request to the backend service (using requests library as an example, actual implementation may vary) |
| 89 | + res = self._post("/retrieval", json=data_json) |
| 90 | + if not res: |
| 91 | + raise Exception([types.TextContent(type="text", text=res.get("Cannot process this operation."))]) |
| 92 | + |
| 93 | + res = res.json() |
| 94 | + if res.get("code") == 0: |
| 95 | + chunks = [] |
| 96 | + for chunk_data in res["data"].get("chunks"): |
| 97 | + chunks.append(json.dumps(chunk_data, ensure_ascii=False)) |
| 98 | + return [types.TextContent(type="text", text="\n".join(chunks))] |
| 99 | + raise Exception([types.TextContent(type="text", text=res.get("message"))]) |
| 100 | + |
| 101 | + |
class RAGFlowCtx:
    """Lifespan-scoped container exposing the shared RAGFlow connector."""

    def __init__(self, connector: "RAGFlowConnector"):
        # Stored under the short name `conn`, as read by the tool handlers.
        self.conn = connector
| 105 | + |
| 106 | + |
@asynccontextmanager
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
    """Build the shared RAGFlow context for the lifetime of the MCP server.

    Yields a dict exposed to request handlers as `lifespan_context`.
    """
    connector = RAGFlowConnector(base_url=BASE_URL)
    context = {"ragflow_ctx": RAGFlowCtx(connector)}
    # The connector holds no open resources, so there is no teardown step.
    yield context
| 115 | + |
| 116 | + |
# Low-level MCP server; server_lifespan supplies the shared RAGFlow context.
app = Server("ragflow-server", lifespan=server_lifespan)
# SSE transport; "/messages/" must match the Mount path registered below.
sse = SseServerTransport("/messages/")
| 119 | + |
| 120 | + |
@app.list_tools()
async def list_tools() -> list[types.Tool]:
    """Advertise the MCP tools offered by this server.

    The caller's dataset catalogue is embedded into the tool description so
    the client model can choose dataset_ids without an extra round trip.
    Raises ValueError when the lifespan context or API key is missing.
    """
    ctx = app.request_context
    ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
    if not ragflow_ctx:
        raise ValueError("Get RAGFlow Context failed")
    connector = ragflow_ctx.conn

    # The API key arrives via the client's initialization headers, forwarded
    # through experimental capabilities in handle_sse().
    api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
    if not api_key:
        raise ValueError("RAGFlow API_KEY is required.")
    connector.bind_api_key(api_key)

    dataset_description = connector.list_datasets()

    return [
        types.Tool(
            # Bug fix: call_tool() dispatches on "ragflow_retrival", but the
            # tool was registered as "retrival", so every invocation failed
            # with "Tool not found".
            name="ragflow_retrival",
            description="Retrieve relevant chunks from the RAGFlow retrieve interface based on the question, using the specified dataset_ids and optionally document_ids. Below is the list of all available datasets, including their descriptions and IDs. If you're unsure which datasets are relevant to the question, simply pass all dataset IDs to the function."
            + dataset_description,
            inputSchema={
                "type": "object",
                # Bug fix: property renamed "documents_ids" -> "document_ids"
                # to match the argument name call_tool() reads.
                "properties": {
                    "dataset_ids": {"type": "array", "items": {"type": "string"}},
                    "document_ids": {"type": "array", "items": {"type": "string"}},
                    "question": {"type": "string"},
                },
                "required": ["dataset_ids", "question"],
            },
        ),
    ]
| 148 | + |
| 149 | + |
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
    """Dispatch an MCP tool invocation to the RAGFlow connector.

    Supports the "ragflow_retrival" tool, which proxies to
    RAGFlowConnector.retrival(). Raises ValueError for an unknown tool name,
    a missing lifespan context, or a missing API key.
    """
    ctx = app.request_context
    ragflow_ctx = ctx.lifespan_context["ragflow_ctx"]
    if not ragflow_ctx:
        raise ValueError("Get RAGFlow Context failed")
    connector = ragflow_ctx.conn

    # Same credential plumbing as list_tools(): the key is forwarded from the
    # client's connection headers via experimental capabilities.
    api_key = ctx.session._init_options.capabilities.experimental["headers"]["api_key"]
    if not api_key:
        raise ValueError("RAGFlow API_KEY is required.")
    connector.bind_api_key(api_key)

    if name == "ragflow_retrival":
        # Bug fix: "document_ids" is optional in the tool schema, so use
        # .get() instead of direct indexing (previously raised KeyError when
        # the client omitted it).
        return connector.retrival(
            dataset_ids=arguments["dataset_ids"],
            document_ids=arguments.get("document_ids"),
            question=arguments["question"],
        )
    raise ValueError(f"Tool not found: {name}")
| 166 | + |
| 167 | + |
async def handle_sse(request):
    """Bridge one incoming SSE connection to the MCP server loop.

    The client's HTTP headers are forwarded as experimental capabilities so
    the tool handlers can read the per-connection api_key.
    """
    # NOTE(review): request._send is a private Starlette attribute; this is
    # the documented pattern for SseServerTransport but may break across
    # Starlette versions.
    async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
        init_options = app.create_initialization_options(experimental_capabilities={"headers": dict(request.headers)})
        await app.run(read_stream, write_stream, init_options)
| 171 | + |
| 172 | + |
class AuthMiddleware(BaseHTTPMiddleware):
    """Reject /sse and /messages requests that carry no api_key header.

    Only presence is checked here; the key is validated by the RAGFlow
    backend when the connector makes its first API call.
    """

    async def dispatch(self, request, call_next):
        if request.url.path.startswith("/sse") or request.url.path.startswith("/messages"):
            api_key = request.headers.get("api_key")
            if not api_key:
                # Bug fix: error text was "Missing unauthorization header",
                # which is not English; corrected to the intended message.
                return JSONResponse({"error": "Missing authorization header"}, status_code=401)
        return await call_next(request)
| 180 | + |
| 181 | + |
# ASGI application: one GET route for the SSE handshake and a mounted POST
# endpoint for client->server messages; AuthMiddleware gates both paths.
# NOTE(review): debug=True leaks tracebacks to clients — consider disabling
# outside development.
starlette_app = Starlette(
    debug=True,
    routes=[
        Route("/sse", endpoint=handle_sse),
        Mount("/messages/", app=sse.handle_post_message),
    ],
    middleware=[Middleware(AuthMiddleware)],
)
| 190 | + |
| 191 | + |
if __name__ == "__main__":
    """
    Launch example:
        uv run mcp/server/server.py --host=127.0.0.1 --port=9382 --base_url=http://127.0.0.1:9380
    """

    import argparse
    import os

    import uvicorn
    from dotenv import load_dotenv

    # Load a local .env file (if present) so RAGFLOW_MCP_* variables are set.
    load_dotenv()

    parser = argparse.ArgumentParser(description="RAGFlow MCP Server, `base_url` and `api_key` are needed.")
    parser.add_argument("--base_url", type=str, default="http://127.0.0.1:9380", help="api_url: http://<host_address>")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="RAGFlow MCP SERVER host")
    parser.add_argument("--port", type=str, default="9382", help="RAGFlow MCP SERVER port")
    args = parser.parse_args()

    # Precedence: environment variable wins over CLI flag, which wins over
    # the module-level defaults. These rebind the module globals read by
    # server_lifespan (BASE_URL) before the server starts.
    BASE_URL = os.environ.get("RAGFLOW_MCP_BASE_URL", args.base_url)
    HOST = os.environ.get("RAGFLOW_MCP_HOST", args.host)
    PORT = os.environ.get("RAGFLOW_MCP_PORT", args.port)

    # Startup banner (raw string so backslashes in the ASCII art survive).
    print(
        r"""
__ __ ____ ____ ____ _____ ______ _______ ____
| \/ |/ ___| _ \ / ___|| ____| _ \ \ / / ____| _ \
| |\/| | | | |_) | \___ \| _| | |_) \ \ / /| _| | |_) |
| | | | |___| __/ ___) | |___| _ < \ V / | |___| _ <
|_| |_|\____|_| |____/|_____|_| \_\ \_/ |_____|_| \_\
 """,
        flush=True,
    )
    print(f"MCP host: {HOST}", flush=True)
    print(f"MCP port: {PORT}", flush=True)
    print(f"MCP base_url: {BASE_URL}", flush=True)

    # PORT is kept as a string throughout; uvicorn requires an int here.
    uvicorn.run(
        starlette_app,
        host=HOST,
        port=int(PORT),
    )
0 commit comments