Step 1: Set up your environment
First, ensure you have Python installed (preferably 3.8+) and set up a virtual environment:
```bash
python -m venv mcp-env
source mcp-env/bin/activate
# On Windows, use: mcp-env\Scripts\activate
```
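If you are unsure which interpreter the virtual environment picked up, a quick sanity check from inside the activated environment (a minimal sketch, nothing project-specific) is:

```python
import sys

# The server code in this guide targets Python 3.8+; this prints the version
# of the interpreter that the activated virtual environment is using.
print(sys.version)
assert sys.version_info >= (3, 8), "Python 3.8 or newer is recommended"
```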
Step 2: Install dependencies
Install the required packages:
```bash
pip install fastapi uvicorn torch transformers pydantic
```
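After installing torch, it is worth confirming whether a GPU is visible, since the server below defaults its device to "cuda" only when one is available and otherwise falls back to CPU in float32. A quick check:

```python
import torch

# The ModelConfig in the server code picks "cuda" as its default device only
# when this prints True; otherwise models load on the CPU in float32.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```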
Step 3: Create the server code
Create a Python file named mcp_server.py for your MCP server implementation:
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp-server")

# Initialize FastAPI app
app = FastAPI(title="MCP Server")

# Model configuration
class ModelConfig(BaseModel):
    model_id: str
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    max_length: int = 2048
    temperature: float = 0.7
    top_p: float = 0.9
    hf_token: Optional[str] = None  # Token for accessing gated models

# Inference request
class InferenceRequest(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = 256
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stop_sequences: Optional[List[str]] = None

# Inference response
class InferenceResponse(BaseModel):
    generated_text: str
    usage: Dict[str, int]

# Global model cache
model_cache = {}

@app.post("/load_model")
async def load_model(config: ModelConfig):
    """Load a model into memory"""
    model_id = config.model_id

    if model_id in model_cache:
        return {"status": "Model already loaded", "model_id": model_id}

    try:
        logger.info(f"Loading model {model_id} on {config.device}")
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=config.hf_token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if config.device == "cuda" else torch.float32,
            device_map=config.device,
            token=config.hf_token
        )

        model_cache[model_id] = {
            "model": model,
            "tokenizer": tokenizer,
            "config": config
        }

        return {"status": "Model loaded successfully", "model_id": model_id}
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")

@app.post("/generate", response_model=InferenceResponse)
async def generate_text(request: InferenceRequest, model_id: str):
    """Generate text using the specified model"""
    if model_id not in model_cache:
        raise HTTPException(status_code=404, detail=f"Model {model_id} not loaded")

    cache_entry = model_cache[model_id]
    model = cache_entry["model"]
    tokenizer = cache_entry["tokenizer"]
    config = cache_entry["config"]

    # Apply request parameters or use defaults from model config
    temperature = request.temperature if request.temperature is not None else config.temperature
    top_p = request.top_p if request.top_p is not None else config.top_p
    max_new_tokens = request.max_new_tokens

    try:
        input_ids = tokenizer.encode(request.prompt, return_tensors="pt").to(config.device)
        input_token_count = input_ids.shape[1]

        # Generate text
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode output
        generated_text = tokenizer.decode(output[0][input_token_count:], skip_special_tokens=True)

        # Handle stop sequences
        if request.stop_sequences:
            for stop_seq in request.stop_sequences:
                if stop_seq in generated_text:
                    generated_text = generated_text[:generated_text.find(stop_seq)]

        # Calculate token usage
        total_tokens = output.shape[1]
        new_tokens = total_tokens - input_token_count

        return InferenceResponse(
            generated_text=generated_text,
            usage={
                "prompt_tokens": input_token_count,
                "completion_tokens": new_tokens,
                "total_tokens": total_tokens
            }
        )
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

@app.post("/unload_model")
async def unload_model(model_id: str):
    """Unload a model from memory"""
    if model_id not in model_cache:
        raise HTTPException(status_code=404, detail=f"Model {model_id} not loaded")

    try:
        # Remove model from cache
        del model_cache[model_id]

        # Force garbage collection
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return {"status": "Model unloaded successfully", "model_id": model_id}
    except Exception as e:
        logger.error(f"Error unloading model: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to unload model: {str(e)}")

@app.get("/models")
async def list_models():
    """List all loaded models"""
    return {
        "models": [
            {
                "model_id": model_id,
                "device": cache["config"].device,
                "max_length": cache["config"].max_length
            }
            for model_id, cache in model_cache.items()
        ]
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "loaded_models": len(model_cache),
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
    }

@app.get("/")
async def root():
    """Root endpoint"""
    return {"status": "MCP Server is running", "version": "1.0.0"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("mcp_server:app", host="0.0.0.0", port=8000, reload=True)
```
Step 4: Run the server
Save the above code to a file named mcp_server.py and run:
```bash
python mcp_server.py
```
Or directly with uvicorn:
```bash
uvicorn mcp_server:app --host 0.0.0.0 --port 8000
```
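Once the server is up, a quick way to confirm it is responding is to call the /health endpoint. The sketch below uses the requests package, which is not in the dependency list above (install it with pip install requests, or simply use curl http://localhost:8000/health):

```python
import requests

# Assumes the server from mcp_server.py is running on localhost:8000
resp = requests.get("http://localhost:8000/health")
print(resp.json())  # e.g. {"status": "healthy", "loaded_models": 0, ...}
```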
You can also try the endpoints from the interactive Swagger UI at http://localhost:8000/docs#/default/load_model_load_model_post. Here is an example request body for /load_model:
```json
{
  "model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "device": "cpu",
  "max_length": 2048,
  "temperature": 0.7,
  "top_p": 0.9,
  "hf_token": "hf_XXXXXXXXX"
}
```
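To drive the same flow from a script rather than the Swagger UI, POST that body to /load_model and then call /generate. Note that /generate takes model_id as a query parameter (it is a plain function argument on the endpoint, not a field of InferenceRequest), while the InferenceRequest fields go in the JSON body. A minimal client sketch, again assuming the requests package and a local server on port 8000:

```python
import requests

BASE_URL = "http://localhost:8000"
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the model; the JSON body follows the ModelConfig schema shown above
resp = requests.post(f"{BASE_URL}/load_model", json={"model_id": MODEL_ID, "device": "cpu"})
print(resp.json())

# Generate text; model_id goes in the query string, the InferenceRequest fields in the body
resp = requests.post(
    f"{BASE_URL}/generate",
    params={"model_id": MODEL_ID},
    json={"prompt": "Explain what this MCP server does in one sentence.", "max_new_tokens": 64},
)
result = resp.json()
print(result["generated_text"])
print(result["usage"])  # prompt_tokens, completion_tokens, total_tokens
```

The usage dictionary in the response mirrors the InferenceResponse model above, so you can track prompt, completion, and total token counts per request.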