disclosure-bureau/infra/embed-service/app.py

148 lines
4 KiB
Python

"""BGE-M3 embedding + reranker microservice.
Self-hosted on the VPS, CPU-only. Loaded lazily on first request, kept warm
in memory thereafter. Two HuggingFace models share ~2.5 GB RAM:
- BGE-M3 (BAAI/bge-m3) — multilingual dense embedding, 1024-dim, 8k context
- BGE-Reranker-v2-M3 (BAAI/bge-reranker-v2-m3) — cross-encoder for reranking
Endpoints:
- POST /embed { texts: string[], normalize?: bool }
- POST /rerank { query: string, docs: string[], normalize?: bool }
- GET /health
- GET /info
"""
from __future__ import annotations
import os
import time
from threading import Lock
from typing import List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# --- Runtime configuration (each value can be overridden via environment) ---
EMBED_MODEL_NAME = os.environ.get("EMBED_MODEL", "BAAI/bge-m3")
RERANK_MODEL_NAME = os.environ.get("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
DEVICE = os.environ.get("DEVICE", "cpu")

# --- Lazily-populated model singletons, each guarded by its own lock ---
# Both stay None until the matching get_*_model() accessor is first called.
_embed_model = None
_embed_lock = Lock()
_rerank_model = None
_rerank_lock = Lock()
def get_embed_model():
    """Return the shared embedding model, constructing it on first use.

    The whole check-and-create sequence runs while holding ``_embed_lock``,
    so concurrent first requests build the model only once. FlagEmbedding is
    imported here rather than at module import time, which keeps startup
    cheap until an embedding is actually requested.
    """
    global _embed_model
    with _embed_lock:
        if _embed_model is not None:
            return _embed_model

        from FlagEmbedding import BGEM3FlagModel

        _embed_model = BGEM3FlagModel(
            EMBED_MODEL_NAME,
            use_fp16=False,
            device=DEVICE,
        )
        return _embed_model
def get_rerank_model():
    """Return the shared reranker model, constructing it on first use.

    The whole check-and-create sequence runs while holding ``_rerank_lock``,
    so concurrent first requests build the model only once. FlagEmbedding is
    imported here rather than at module import time, which keeps startup
    cheap until a rerank is actually requested.
    """
    global _rerank_model
    with _rerank_lock:
        if _rerank_model is not None:
            return _rerank_model

        from FlagEmbedding import FlagReranker

        _rerank_model = FlagReranker(
            RERANK_MODEL_NAME,
            use_fp16=False,
            device=DEVICE,
        )
        return _rerank_model
# Application object; the route decorators in this module register against it.
app = FastAPI(
    title="Disclosure Bureau Embed Service",
    version="0.1.0",
)
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # Texts to embed; the list must hold between 1 and 512 items.
    texts: List[str] = Field(
        ...,
        min_items=1,
        max_items=512,
    )
    # When true (the default) the returned vectors are scaled to unit L2 norm.
    normalize: bool = True
class EmbedResponse(BaseModel):
    """Response body for ``POST /embed``."""

    # Name of the embedding model that produced the vectors.
    model: str
    # Length of each returned vector.
    dim: int
    # Number of vectors returned.
    count: int
    # Time spent handling the request, in milliseconds.
    elapsed_ms: int
    # One vector of floats per input text.
    embeddings: List[List[float]]
class RerankRequest(BaseModel):
    """Request body for ``POST /rerank``."""

    # Query that every document is scored against.
    query: str
    # Candidate documents; the list must hold between 1 and 200 items.
    docs: List[str] = Field(
        ...,
        min_items=1,
        max_items=200,
    )
    # Passed through as the reranker's ``normalize`` argument (default true).
    normalize: bool = True
class RerankResponse(BaseModel):
    """Response body for ``POST /rerank``."""

    # Name of the reranker model that produced the scores.
    model: str
    # Time spent handling the request, in milliseconds.
    elapsed_ms: int
    # One score per submitted document.
    scores: List[float]
@app.get("/health")
def health():
    """Report liveness plus whether each model has been loaded yet.

    Reads the module-level singletons directly, so calling this endpoint
    never triggers a model load.
    """
    embed_ready = _embed_model is not None
    rerank_ready = _rerank_model is not None
    payload = {
        "status": "ok",
        "embed_loaded": embed_ready,
        "rerank_loaded": rerank_ready,
    }
    return payload
@app.get("/info")
def info():
    """Return the service's static configuration.

    NOTE(review): ``embed_dim`` is a fixed 1024 and is not derived from the
    configured model, so it may be wrong if ``EMBED_MODEL`` is overridden.
    """
    payload = {
        "embed_model": EMBED_MODEL_NAME,
        "rerank_model": RERANK_MODEL_NAME,
        "device": DEVICE,
        "embed_dim": 1024,
    }
    return payload
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Encode ``req.texts`` into dense vectors with the embedding model.

    Only the dense output is requested from the model; sparse and ColBERT
    outputs are disabled. When ``req.normalize`` is true, every vector is
    divided by its L2 norm (plus a tiny epsilon so an all-zero vector cannot
    cause a division by zero).

    Args:
        req: Validated request carrying the texts and the normalize flag.

    Returns:
        EmbedResponse with the model name, vector dimension, vector count,
        elapsed time in milliseconds, and the vectors as lists of floats.

    Raises:
        HTTPException: Status 500 if loading the model or encoding fails;
            the original exception is attached as the cause.
    """
    # perf_counter() is monotonic, so the reported duration cannot go
    # negative or jump when the system clock is adjusted (time.time() can).
    started = time.perf_counter()
    try:
        import numpy as np

        model = get_embed_model()
        out = model.encode(
            req.texts,
            batch_size=min(len(req.texts), 16),
            max_length=8192,
            return_dense=True,
            return_sparse=False,
            return_colbert_vecs=False,
        )
        vectors = np.asarray(out["dense_vecs"])
        if req.normalize:
            norms = np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-12
            vectors = vectors / norms
        return EmbedResponse(
            model=EMBED_MODEL_NAME,
            dim=int(vectors.shape[1]),
            count=int(vectors.shape[0]),
            elapsed_ms=int((time.perf_counter() - started) * 1000),
            # tolist() converts every element to a built-in float in one pass.
            embeddings=vectors.tolist(),
        )
    except Exception as e:
        # Service boundary: surface any failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=f"embed failed: {e}") from e
@app.post("/rerank", response_model=RerankResponse)
def rerank(req: RerankRequest):
    """Score every document in ``req.docs`` against ``req.query``.

    Each document is paired with the query and the pairs are handed to the
    reranker's ``compute_score``; ``req.normalize`` is forwarded unchanged
    as its ``normalize`` argument.

    Args:
        req: Validated request carrying the query, documents and flag.

    Returns:
        RerankResponse with the model name, elapsed time in milliseconds,
        and one float score per document, in input order.

    Raises:
        HTTPException: Status 500 if loading the model or scoring fails;
            the original exception is attached as the cause.
    """
    from numbers import Real

    # perf_counter() is monotonic, so the reported duration cannot go
    # negative or jump when the system clock is adjusted (time.time() can).
    started = time.perf_counter()
    try:
        model = get_rerank_model()
        pairs = [[req.query, doc] for doc in req.docs]
        scores = model.compute_score(pairs, normalize=req.normalize)
        # A lone result may come back as a bare scalar instead of a list;
        # Real also matches real-number scalars that are not built-in floats.
        if isinstance(scores, Real):
            scores = [scores]
        return RerankResponse(
            model=RERANK_MODEL_NAME,
            elapsed_ms=int((time.perf_counter() - started) * 1000),
            scores=[float(s) for s in scores],
        )
    except Exception as e:
        # Service boundary: surface any failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=f"rerank failed: {e}") from e