"""BGE-M3 embedding + reranker microservice.

Self-hosted on the VPS, CPU-only. Loaded lazily on first request, kept warm
in memory thereafter. Two HuggingFace models share ~2.5 GB RAM:
  - BGE-M3 (BAAI/bge-m3) — multilingual dense embedding, 1024-dim, 8k context
  - BGE-Reranker-v2-M3 (BAAI/bge-reranker-v2-m3) — cross-encoder for reranking

Endpoints:
  - POST /embed   { texts: string[], normalize?: bool }
  - POST /rerank  { query: string, docs: string[], normalize?: bool }
  - GET  /health
  - GET  /info
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
from threading import Lock
|
|
from typing import List, Optional
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Model identifiers and the target device. Each value can be overridden through
# the environment variable named in the lookup; otherwise the default applies.
EMBED_MODEL_NAME = os.environ.get("EMBED_MODEL", "BAAI/bge-m3")
RERANK_MODEL_NAME = os.environ.get("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
DEVICE = os.environ.get("DEVICE", "cpu")

# Model singletons start out unset and are filled in by the get_*_model()
# accessors; each one has its own lock guarding the check-and-create step.
_embed_model, _rerank_model = None, None
_embed_lock, _rerank_lock = Lock(), Lock()
|
|
|
|
|
|
def get_embed_model():
    """Return the shared embedding model, creating it on the first call.

    The check and the construction both happen while ``_embed_lock`` is held,
    and the created instance is cached in the module-level ``_embed_model``.
    """
    global _embed_model
    with _embed_lock:
        if _embed_model is not None:
            return _embed_model

        # Imported here so FlagEmbedding is only pulled in when the model is
        # actually needed for the first time.
        from FlagEmbedding import BGEM3FlagModel

        _embed_model = BGEM3FlagModel(
            EMBED_MODEL_NAME,
            use_fp16=False,
            device=DEVICE,
        )
        return _embed_model
|
|
|
|
|
|
def get_rerank_model():
    """Return the shared reranker model, creating it on the first call.

    The check and the construction both happen while ``_rerank_lock`` is held,
    and the created instance is cached in the module-level ``_rerank_model``.
    """
    global _rerank_model
    with _rerank_lock:
        if _rerank_model is not None:
            return _rerank_model

        # Imported here so FlagEmbedding is only pulled in when the model is
        # actually needed for the first time.
        from FlagEmbedding import FlagReranker

        _rerank_model = FlagReranker(
            RERANK_MODEL_NAME,
            use_fp16=False,
            device=DEVICE,
        )
        return _rerank_model
|
|
|
|
|
|
# Application object that the route decorators in this module register on.
app = FastAPI(
    title="Disclosure Bureau Embed Service",
    version="0.1.0",
)
|
|
|
|
|
|
class EmbedRequest(BaseModel):
    # Request body accepted by the embed() handler.

    # Texts to embed; the list must hold between 1 and 512 items.
    texts: List[str] = Field(
        ...,
        min_items=1,
        max_items=512,
    )

    # When true (the default), embed() rescales each vector to unit L2 norm.
    normalize: bool = True
|
|
|
|
|
|
class EmbedResponse(BaseModel):
    # Response body produced by the embed() handler.

    # Name of the embedding model that produced the vectors.
    model: str

    # Length of the first returned vector.
    dim: int

    # Number of vectors returned.
    count: int

    # Time spent handling the request, in whole milliseconds.
    elapsed_ms: int

    # One list of floats per returned vector.
    embeddings: List[List[float]]
|
|
|
|
|
|
class RerankRequest(BaseModel):
    # Request body accepted by the rerank() handler.

    # Query that every document is paired with.
    query: str

    # Candidate documents; the list must hold between 1 and 200 items.
    docs: List[str] = Field(
        ...,
        min_items=1,
        max_items=200,
    )

    # Forwarded unchanged as the ``normalize`` argument of compute_score().
    normalize: bool = True
|
|
|
|
|
|
class RerankResponse(BaseModel):
    # Response body produced by the rerank() handler.

    # Name of the reranker model that produced the scores.
    model: str

    # Time spent handling the request, in whole milliseconds.
    elapsed_ms: int

    # Scores as returned by the reranker, converted to plain floats.
    scores: List[float]
|
|
|
|
|
|
@app.get("/health")
def health():
    # Liveness probe. It only inspects the cached singletons, so calling it
    # never triggers a model load.
    embed_ready = _embed_model is not None
    rerank_ready = _rerank_model is not None
    return {
        "status": "ok",
        "embed_loaded": embed_ready,
        "rerank_loaded": rerank_ready,
    }
|
|
|
|
|
|
@app.get("/info")
def info():
    # Static description of the configured models and device; no model is
    # loaded to answer this request.
    details = {
        "embed_model": EMBED_MODEL_NAME,
        "rerank_model": RERANK_MODEL_NAME,
        "device": DEVICE,
    }
    # NOTE(review): the dimension is a fixed constant, so it is not adjusted
    # when EMBED_MODEL points at a different model — confirm this is intended.
    details["embed_dim"] = 1024
    return details
|
|
|
|
|
|
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Embed each input text into a dense vector.

    The embedding model is loaded on first use. When ``normalize`` is true,
    every vector is rescaled to unit L2 norm before it is returned.

    Raises:
        HTTPException: status 500 with an ``embed failed: ...`` detail when
            loading the model, encoding, or building the response fails.
    """
    # perf_counter() is monotonic, so the reported duration cannot go negative
    # or jump if the system wall clock is adjusted during the request.
    t0 = time.perf_counter()
    try:
        model = get_embed_model()
        out = model.encode(
            req.texts,
            batch_size=min(len(req.texts), 16),
            max_length=8192,
            return_dense=True,
            return_sparse=False,
            return_colbert_vecs=False,
        )
        vectors = out["dense_vecs"]
        if req.normalize:
            import numpy as np

            arr = np.asarray(vectors)
            # The small epsilon keeps the division defined for an all-zero row.
            norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
            vectors = arr / norms
        return EmbedResponse(
            model=EMBED_MODEL_NAME,
            dim=len(vectors[0]),
            count=len(vectors),
            elapsed_ms=int((time.perf_counter() - t0) * 1000),
            embeddings=[list(map(float, v)) for v in vectors],
        )
    except Exception as e:
        # Chain the original error so its traceback is kept as the cause.
        raise HTTPException(status_code=500, detail=f"embed failed: {e}") from e
|
|
|
|
|
|
@app.post("/rerank", response_model=RerankResponse)
def rerank(req: RerankRequest):
    """Score every document against the query with the reranker model.

    The reranker model is loaded on first use. Scores are returned in the
    same order as the submitted documents.

    Raises:
        HTTPException: status 500 with a ``rerank failed: ...`` detail when
            loading the model, scoring, or building the response fails.
    """
    import numbers

    # perf_counter() is monotonic, so the reported duration cannot go negative
    # or jump if the system wall clock is adjusted during the request.
    t0 = time.perf_counter()
    try:
        model = get_rerank_model()
        pairs = [[req.query, doc] for doc in req.docs]
        scores = model.compute_score(pairs, normalize=req.normalize)
        # A bare scalar result is wrapped so the response always carries a
        # list. numbers.Real accepts float as before, plus other real scalar
        # types that would otherwise fail in the comprehension below.
        if isinstance(scores, numbers.Real):
            scores = [scores]
        return RerankResponse(
            model=RERANK_MODEL_NAME,
            elapsed_ms=int((time.perf_counter() - t0) * 1000),
            scores=[float(s) for s in scores],
        )
    except Exception as e:
        # Chain the original error so its traceback is kept as the cause.
        raise HTTPException(status_code=500, detail=f"rerank failed: {e}") from e
|