fix(mcp): fix indentation error in mcp_integration.py

pull/819/head
Pavan Kumar 3 months ago committed by ascender1729
parent 1a84b24394
commit 16a2321525

@ -1,22 +1,22 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import asyncio import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Literal from typing import Any, Dict, List, Optional, Literal
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from loguru import logger from loguru import logger
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage from mcp.types import CallToolResult, JSONRPCMessage
from swarms.utils.any_to_str import any_to_str from swarms.utils.any_to_str import any_to_str
class MCPServer(abc.ABC): class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers.""" """Base class for Model Context Protocol servers."""
@abc.abstractmethod @abc.abstractmethod
@ -48,7 +48,7 @@
pass pass
class _MCPServerWithClientSession(MCPServer, abc.ABC): class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Mixin providing ClientSession-based MCP communication.""" """Mixin providing ClientSession-based MCP communication."""
def __init__(self, cache_tools_list: bool = False): def __init__(self, cache_tools_list: bool = False):
@ -127,7 +127,7 @@
return await self.session.call_tool(name, arguments) return await self.session.call_tool(name, arguments)
class MCPServerStdioParams(TypedDict): class MCPServerStdioParams(TypedDict):
"""Configuration for stdio transport.""" """Configuration for stdio transport."""
command: str command: str
args: NotRequired[List[str]] args: NotRequired[List[str]]
@ -137,7 +137,7 @@
encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]] encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
class MCPServerStdio(_MCPServerWithClientSession): class MCPServerStdio(_MCPServerWithClientSession):
"""MCP server over stdio transport.""" """MCP server over stdio transport."""
def __init__( def __init__(
@ -170,7 +170,7 @@
return self._name return self._name
class MCPServerSseParams(TypedDict): class MCPServerSseParams(TypedDict):
"""Configuration for HTTP+SSE transport.""" """Configuration for HTTP+SSE transport."""
url: str url: str
headers: NotRequired[Dict[str, str]] headers: NotRequired[Dict[str, str]]
@ -178,7 +178,7 @@
sse_read_timeout: NotRequired[float] sse_read_timeout: NotRequired[float]
class MCPServerSse(_MCPServerWithClientSession): class MCPServerSse(_MCPServerWithClientSession):
"""MCP server over HTTP with SSE transport.""" """MCP server over HTTP with SSE transport."""
def __init__( def __init__(
@ -209,9 +209,9 @@
return self._name return self._name
async def call_tool_fast( async def call_tool_fast(
server: MCPServerSse, payload: Dict[str, Any] | str server: MCPServerSse, payload: Dict[str, Any] | str
) -> Any: ) -> Any:
try: try:
await server.connect() await server.connect()
result = await server.call_tool(arguments=payload if isinstance(payload, dict) else None) result = await server.call_tool(arguments=payload if isinstance(payload, dict) else None)
@ -220,25 +220,25 @@
await server.cleanup() await server.cleanup()
async def mcp_flow_get_tool_schema( async def mcp_flow_get_tool_schema(
params: MCPServerSseParams, params: MCPServerSseParams,
) -> Any: ) -> Any:
async with MCPServerSse(params) as server: async with MCPServerSse(params) as server:
tools = await server.list_tools() tools = await server.list_tools()
return tools return tools
async def mcp_flow( async def mcp_flow(
params: MCPServerSseParams, params: MCPServerSseParams,
function_call: Dict[str, Any] | str, function_call: Dict[str, Any] | str,
) -> Any: ) -> Any:
async with MCPServerSse(params) as server: async with MCPServerSse(params) as server:
return await call_tool_fast(server, function_call) return await call_tool_fast(server, function_call)
async def _call_one_server( async def _call_one_server(
params: MCPServerSseParams, payload: Dict[str, Any] | str params: MCPServerSseParams, payload: Dict[str, Any] | str
) -> Any: ) -> Any:
server = MCPServerSse(params) server = MCPServerSse(params)
try: try:
await server.connect() await server.connect()
@ -247,9 +247,9 @@
await server.cleanup() await server.cleanup()
def batch_mcp_flow( def batch_mcp_flow(
params: List[MCPServerSseParams], payload: Dict[str, Any] | str params: List[MCPServerSseParams], payload: Dict[str, Any] | str
) -> List[Any]: ) -> List[Any]:
return asyncio.run( return asyncio.run(
asyncio.gather(*[_call_one_server(p, payload) for p in params]) asyncio.gather(*[_call_one_server(p, payload) for p in params])
) )
Loading…
Cancel
Save