17 changes: 11 additions & 6 deletions python/mcp-ollama-rag/README.md
@@ -22,11 +22,12 @@ components:
 Function)
 - **MCPServer Class**: FastMCP-based server implementing HTTP-streamable MCP
 protocol
-- **MCP Tools**: Three primary tools for Ollama interaction:
+- **MCP Tools**: Four tools for Ollama interaction:
   - `list_models`: Enumerate available models on the Ollama server
   - `pull_model`: Download and install new models
-  - `call_model`: Send prompts to models and receive responses
-  - `rag_document`: RAG a document - accepts urls or text (strings)
+  - `embed_document`: Embed documents for RAG - accepts URLs or text strings,
+    automatically chunks to fit the embedding model's context window
+  - `call_model`: Send prompts to models with RAG context and receive responses


 ## Setup
@@ -42,7 +43,7 @@ protocol
 ```bash

 # optionally setup venv
-pythom -m venv venv
+python -m venv venv
 source venv/bin/activate

 # and install deps
@@ -57,8 +58,9 @@ protocol
 # Start Ollama service (in different terminal/ in bg)
 ollama serve

-# Pull a model (optional, can be done via MCP tool)
+# Pull models
 ollama pull llama3.2:3b
+ollama pull mxbai-embed-large
 ```

 Now you have a running Ollama Server
@@ -83,7 +85,7 @@ ollama server and call the (now) specialized inference model with prompts.

 Now you've connected via MCP protocol to the running function, using an MCP client
 which has embedded a document into vector space for RAG tooling and prompted the
-model which can use the embeddings to answer your question (hopefuly) in a more
+model which can use the embeddings to answer your question (hopefully) in a more
 sophisticated manner.

 ### Deployment to cluster (not tested)
@@ -109,3 +111,6 @@ or portforwarding etc.
 - Verify model availability with `ollama list`
 - Confirm function is running on expected port (default: 8080)

+**HuggingFace rate limits:**
+The tokenizer for document chunking is downloaded from HuggingFace Hub on first
+use. If you hit rate limits, log in with `huggingface-cli login`.
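Since the chunking tokenizer is fetched from HuggingFace Hub, another way to sidestep rate limits is to pre-download it once before starting the function. A minimal sketch, assuming the tokenizer matches the default `mxbai-embed-large` embedding model; the exact tokenizer repo used by the chunker is an assumption, not confirmed by this diff:

```python
# Pre-fetch the chunking tokenizer so the first embed_document call
# doesn't hit HuggingFace Hub. The repo name below is an assumption.
from transformers import AutoTokenizer

AutoTokenizer.from_pretrained("mixedbread-ai/mxbai-embed-large-v1")
```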
22 changes: 8 additions & 14 deletions python/mcp-ollama-rag/client/client.py
@@ -1,12 +1,6 @@
 import asyncio
 from mcp import ClientSession
 from mcp.client.streamable_http import streamablehttp_client
-import json

-from mcp.types import CallToolResult
-
-def unload_list_models(models: CallToolResult) -> list[str]:
-    return [json.loads(item.text)["model"] for item in models.content if item.text.strip().startswith('{')] #pyright: ignore
-
 async def main():
     # check your running Function MCP Server, it will output where it's available
@@ -15,12 +9,12 @@ async def main():
         read_stream,write_stream = streams[0],streams[1]

         async with ClientSession(read_stream,write_stream) as sess:
-            print("Initializing connection...",end="")
-            await sess.initialize()
+            print("Initializing connection...", end="", flush=True)
+            _ = await sess.initialize()
             print("done!\n")


-            # embed some documents
+            print("Embedding documents (this may take a moment)...", flush=True)
             embed = await sess.call_tool(
                 name="embed_document",
                 arguments={
@@ -30,17 +24,17 @@ async def main():
                     ],
                 }
             )
-            print(embed.content[0].text) # pyright: ignore[reportAttributeAccessIssue]
+            print(embed.content[0].text)  # pyright: ignore[reportAttributeAccessIssue]
             print("-"*60)

             # prompt the inference model
+            prompt = "What actually is a Knative Function?"
+            print(f"Querying: \"{prompt}\"", flush=True)
             resp = await sess.call_tool(
                 name="call_model",
-                arguments={
-                    "prompt": "What actually is a Knative Function?",
-                }
+                arguments={"prompt": prompt},
             )
-            print(resp.content[0].text)
+            print(resp.content[0].text)  # pyright: ignore[reportAttributeAccessIssue]

 if __name__ == "__main__":
     asyncio.run(main())
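To sanity-check the connection before calling tools, the session can first list what the server exposes. A small sketch (not part of this diff) that would slot into `main()` right after `sess.initialize()`, at the same indentation as the other `sess` calls:

```python
# List the tools registered by the function's MCP server --
# should show list_models, pull_model, embed_document, call_model.
tools = await sess.list_tools()
print("Available tools:", [t.name for t in tools.tools])
```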
187 changes: 94 additions & 93 deletions python/mcp-ollama-rag/function/func.py
@@ -2,28 +2,25 @@

 # Function as an MCP Server implementation
 import logging
+import uuid

 from mcp.server.fastmcp import FastMCP
 import ollama
 import asyncio
 import chromadb
-import requests

+from .parser import resolve_input, chunk_text
+
+# Silence noisy library loggers
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.getLogger("mcp").setLevel(logging.WARNING)
+
 def new():
-    """ New is the only method that must be implemented by a Function.
+    """New is the only method that must be implemented by a Function.
     The instance returned can be of any name.
     """
     return Function()

-# Accepts any url link which points to a raw data (*.md/text files etc.)
-# example: https://raw.githubusercontent.com/knative/func/main/docs/function-templates/python.md
-def get_raw_content(url: str) -> str:
-    """ retrieve contents of github raw url as a text """
-    response = requests.get(url)
-    response.raise_for_status() # errors if bad response
-    print(f"fetch '{url}' - ok")
-    return response.text
-
 class MCPServer:
     """
     MCP server that exposes a chat with an LLM model running on Ollama server
@@ -39,69 +36,67 @@ def __init__(self):

         self.client = ollama.Client()

-        #init database stuff
+        # init vector database
         self.dbClient = chromadb.Client()
-        self.collection = self.dbClient.create_collection(name="my_collection")
+        self.collection = self.dbClient.get_or_create_collection(
+            name="my_collection"
+        )
         # default embedding model
         self.embedding_model = "mxbai-embed-large"
         # call this after self.embedding_model assignment, so it's defined
         self._register_tools()

     def _register_tools(self):
         """Register MCP tools."""

         @self.mcp.tool()
         def list_models():
             """List all models currently available on the Ollama server"""
             try:
                 models = self.client.list()
             except Exception as e:
                 return f"Oops, failed to list models because: {str(e)}"
-            #return [model['name'] for model in models['models']]
             return [model for model in models]

         default_embedding_model = self.embedding_model

         @self.mcp.tool()
-        def embed_document(data:list[str],model:str = default_embedding_model) -> str:
+        def embed_document(
+            data: list[str], model: str = default_embedding_model
+        ) -> str:
             """
             RAG (Retrieval-augmented generation) tool.
-            Embeds documents provided in data.
-            Arguments:
-            - data: expected to be of type str|list.
-            - model: embedding model to use, examples below.
-
-            # example embedding models:
-            # mxbai-embed-large - 334M *default
-            # nomic-embed-text - 137M
-            # all-minilm - 23M
+            Embeds documents provided in data. Each item can be a URL
+            (fetched automatically) or a raw text string. Documents are
+            chunked to fit the embedding model's context window.
+
+            Args:
+                data: List of URLs or text strings to embed.
+                model: Embedding model to use. Example:
+                    - mxbai-embed-large - default
             """
-            count = 0
-
-            ############ TODO -- import im a separate file
-            # documents generator
-            #documents_gen = parse_data_generator(data)
-            #### 1) GENERATE
-            # generate vector embeddings via embedding model
-            #for i, d in enumerate(documents_gen):
-            #    response = ollama.embed(model=model,input=d)
-            #    embeddings = response["embeddings"]
-            #    self.collection.add(
-            #        ids=[str(i)],
-            #        embeddings=embeddings,
-            #        documents=[d]
-            #    )
-            #    count += 1
-
-            # for simplicity (until the above is resolved, this accecpts only URLs)
-            for i, d in enumerate(data):
-                response = ollama.embed(model=model,input=get_raw_content(d))
-                embeddings = response["embeddings"]
-                self.collection.add(
-                    ids=[str(i)],
-                    embeddings=embeddings,
-                    documents=[d]
-                )
-                count += 1
-            return f"ok - Embedded {count} documents"
+            all_chunks = []
+            for item in data:
+                content = resolve_input(item)
+                chunks = chunk_text(content)
+                all_chunks.extend(chunks)
+                label = item[:60] + "..." if len(item) > 60 else item
+                print(f"  Chunked '{label}' into {len(chunks)} chunks", flush=True)
+
+            # Batch embed all chunks in one call for performance
+            print(f"  Embedding {len(all_chunks)} chunks...", flush=True)
+            response = ollama.embed(model=model, input=all_chunks)
+            ids = [str(uuid.uuid4()) for _ in all_chunks]
+            self.collection.add(
+                ids=ids,
+                embeddings=response["embeddings"],
+                documents=all_chunks,
+            )
+            print("  Done.", flush=True)
+
+            return (
+                f"ok - Embedded {len(data)} document(s) "
+                f"as {len(all_chunks)} chunks"
+            )

         @self.mcp.tool()
         def pull_model(model: str) -> str:
@@ -113,42 +108,42 @@ def pull_model(model: str) -> str:
             return f"Success! model {model} is available"

         @self.mcp.tool()
-        def call_model(prompt: str,
-                       model: str = "llama3.2:3b",
-                       embed_model: str = self.embedding_model) -> str:
-            """Send a prompt to a model being served on ollama server"""
-            #### 2) RETRIEVE
-            # we embed the prompt but dont save it into db, then we retrieve
-            # the most relevant document (most similar vectors)
+        def call_model(
+            prompt: str,
+            model: str = "llama3.2:3b",
+            embed_model: str = self.embedding_model,
+        ) -> str:
+            """Send a prompt to a model being served on ollama server.
+            Uses RAG to find the most relevant embedded documents and
+            includes them as context for the response."""
             try:
-                response = ollama.embed(
-                    model=embed_model,
-                    input=prompt
-                )
+                # Embed the prompt for similarity search
+                response = ollama.embed(model=embed_model, input=prompt)
                 results = self.collection.query(
-                    query_embeddings=response["embeddings"],
-                    n_results=1
-                )
-                data = results['documents'][0][0]
+                    query_embeddings=response["embeddings"],
+                    n_results=3,
+                )
+                context = "\n\n".join(results["documents"][0])

-                #### 3) GENERATE
-                # generate answer given a combination of prompt and data retrieved
                 output = ollama.generate(
-                    model=model,
-                    prompt=f'Using data: {data}, respond to prompt: {prompt}'
-                )
-                print(output)
+                    model=model,
+                    prompt=(
+                        f"Using the following context:\n{context}\n\n"
+                        f"Respond to: {prompt}"
+                    ),
+                )
             except Exception as e:
                 return f"Error occurred during calling the model: {str(e)}"
-            return output['response']
+            return output["response"]

     async def handle(self, scope, receive, send):
         """Handle ASGI requests - both lifespan and HTTP."""
         await self._app(scope, receive, send)


 class Function:
     def __init__(self):
-        """ The init method is an optional method where initialization can be
+        """The init method is an optional method where initialization can be
         performed. See the start method for a startup hook which includes
         configuration.
         """
@@ -166,7 +161,7 @@ async def handle(self, scope, receive, send):
         await self._initialize_mcp()

         # Route MCP requests
-        if scope['path'].startswith('/mcp'):
+        if scope.get("path", "").startswith("/mcp"):
             await self.mcp_server.handle(scope, receive, send)
             return

@@ -175,26 +170,28 @@ async def handle(self, scope, receive, send):

     async def _initialize_mcp(self):
         """Initialize the MCP server by sending lifespan startup event."""
-        lifespan_scope = {'type': 'lifespan', 'asgi': {'version': '3.0'}}
+        lifespan_scope = {"type": "lifespan", "asgi": {"version": "3.0"}}
         startup_sent = False

         async def lifespan_receive():
             nonlocal startup_sent
             if not startup_sent:
                 startup_sent = True
-                return {'type': 'lifespan.startup'}
+                return {"type": "lifespan.startup"}
             await asyncio.Event().wait()  # Wait forever for shutdown

         async def lifespan_send(message):
-            if message['type'] == 'lifespan.startup.complete':
+            if message["type"] == "lifespan.startup.complete":
                 self._mcp_initialized = True
-            elif message['type'] == 'lifespan.startup.failed':
+            elif message["type"] == "lifespan.startup.failed":
                 logging.error(f"MCP startup failed: {message}")

         # Start lifespan in background
-        asyncio.create_task(self.mcp_server.handle(
-            lifespan_scope, lifespan_receive, lifespan_send
-        ))
+        asyncio.create_task(
+            self.mcp_server.handle(
+                lifespan_scope, lifespan_receive, lifespan_send
+            )
+        )

         # Brief wait for startup completion
         await asyncio.sleep(0.1)
@@ -204,15 +201,19 @@ async def _send_default_response(self, send):
         Send default OK response.
         This is for your non MCP requests if desired.
         """
-        await send({
-            'type': 'http.response.start',
-            'status': 200,
-            'headers': [[b'content-type', b'text/plain']],
-        })
-        await send({
-            'type': 'http.response.body',
-            'body': b'OK',
-        })
+        await send(
+            {
+                "type": "http.response.start",
+                "status": 200,
+                "headers": [[b"content-type", b"text/plain"]],
+            }
+        )
+        await send(
+            {
+                "type": "http.response.body",
+                "body": b"OK",
+            }
+        )

     def start(self, cfg):
         logging.info("Function starting")
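The new `embed_document` path relies on `resolve_input` and `chunk_text` from `function/parser.py`, which is not shown in this diff. A hypothetical sketch of what those helpers could look like, assuming URL detection by scheme and token-window chunking with a HuggingFace tokenizer; the tokenizer repo, chunk size, and overlap are guesses:

```python
# Hypothetical sketch of function/parser.py -- not part of this diff.
import requests
from transformers import AutoTokenizer

# Assumed tokenizer repo; the real one may differ.
_tokenizer = AutoTokenizer.from_pretrained("mixedbread-ai/mxbai-embed-large-v1")

def resolve_input(item: str) -> str:
    """Return raw text: fetch URLs, pass plain text through unchanged."""
    if item.startswith(("http://", "https://")):
        response = requests.get(item, timeout=30)
        response.raise_for_status()
        return response.text
    return item

def chunk_text(text: str, max_tokens: int = 512, overlap: int = 50) -> list[str]:
    """Split text into overlapping chunks that fit the embedding context window."""
    tokens = _tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for start in range(0, len(tokens), max_tokens - overlap):
        window = tokens[start:start + max_tokens]
        chunks.append(_tokenizer.decode(window))
    return chunks
```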