feat: ShootTracker SQLite+JWT+YOLOv8
This commit is contained in:
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
ShootTracker — Microservice IA
|
||||
FastAPI + YOLOv8 + OpenCV
|
||||
|
||||
Endpoints :
|
||||
POST /analyze — Analyse d'une photo de cible
|
||||
GET /health — Santé du service
|
||||
GET / — Infos
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from analyzer import TargetAnalyzer, AnalysisResult
|
||||
|
||||
# ─── Logging ──────────────────────────────────────────────────────────────────
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ─── Modèle global ────────────────────────────────────────────────────────────
|
||||
analyzer: Optional[TargetAnalyzer] = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Charge le modèle au démarrage."""
|
||||
global analyzer
|
||||
model_path = os.getenv('YOLO_MODEL_PATH', 'yolov8n.pt')
|
||||
logger.info(f"Chargement du modèle YOLOv8 : {model_path}")
|
||||
analyzer = TargetAnalyzer(model_path=model_path)
|
||||
logger.info("✓ Service IA prêt")
|
||||
yield
|
||||
logger.info("Service IA arrêté")
|
||||
|
||||
# ─── Application FastAPI ──────────────────────────────────────────────────────
|
||||
app = FastAPI(
|
||||
title="ShootTracker AI Service",
|
||||
description="Analyse d'impacts de tir avec YOLOv8 + OpenCV",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ─── Schémas Pydantic ─────────────────────────────────────────────────────────
|
||||
|
||||
class AnalyzeRequest(BaseModel):
|
||||
image_base64: str
|
||||
target_type: str
|
||||
previous_image_base64: Optional[str] = None
|
||||
|
||||
@field_validator('target_type')
|
||||
@classmethod
|
||||
def validate_target_type(cls, v: str) -> str:
|
||||
allowed = ['issf', 'silhouette', 'libre']
|
||||
if v not in allowed:
|
||||
raise ValueError(f"target_type doit être parmi : {allowed}")
|
||||
return v
|
||||
|
||||
@field_validator('image_base64')
|
||||
@classmethod
|
||||
def validate_image(cls, v: str) -> str:
|
||||
if not v or len(v) < 100:
|
||||
raise ValueError("image_base64 invalide ou trop courte")
|
||||
return v
|
||||
|
||||
class ImpactOut(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
zone: Optional[int]
|
||||
points: int
|
||||
confidence: float
|
||||
|
||||
class AnalyzeResponse(BaseModel):
|
||||
impacts: list[ImpactOut]
|
||||
total_detected: int
|
||||
score_zone: int
|
||||
score_groupement: int
|
||||
score_total: int
|
||||
dispersion_radius: float
|
||||
center_x: float
|
||||
center_y: float
|
||||
annotated_image_base64: Optional[str]
|
||||
error: Optional[str]
|
||||
warning: Optional[str]
|
||||
|
||||
# ─── Endpoints ────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/", tags=["Info"])
|
||||
async def root():
|
||||
return {
|
||||
"service": "ShootTracker AI",
|
||||
"version": "1.0.0",
|
||||
"status": "running",
|
||||
"model": os.getenv('YOLO_MODEL_PATH', 'yolov8n.pt'),
|
||||
}
|
||||
|
||||
@app.get("/health", tags=["Info"])
|
||||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": analyzer is not None and analyzer.model is not None,
|
||||
"service": "shoottracker-ai",
|
||||
}
|
||||
|
||||
@app.post("/analyze", response_model=AnalyzeResponse, tags=["Analyse"])
|
||||
async def analyze_target(req: AnalyzeRequest):
|
||||
"""
|
||||
Analyse une photo de cible de tir.
|
||||
|
||||
- Détecte les impacts avec YOLOv8
|
||||
- Analyse les zones selon le type de cible
|
||||
- Si image_précédente fournie : ne compte que les NOUVEAUX impacts
|
||||
- Calcule le score et le groupement
|
||||
- Retourne la photo annotée en base64
|
||||
"""
|
||||
if analyzer is None:
|
||||
raise HTTPException(status_code=503, detail="Service IA non initialisé")
|
||||
|
||||
logger.info(f"Analyse demandée : target_type={req.target_type}, prev={'oui' if req.previous_image_base64 else 'non'}")
|
||||
|
||||
try:
|
||||
result: AnalysisResult = analyzer.analyze(
|
||||
image_b64=req.image_base64,
|
||||
target_type=req.target_type,
|
||||
previous_image_b64=req.previous_image_base64,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Erreur inattendue: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Erreur interne : {str(e)}")
|
||||
|
||||
if result.error and not result.annotated_image_base64:
|
||||
# Erreur critique (image illisible)
|
||||
raise HTTPException(status_code=422, detail=result.error)
|
||||
|
||||
return AnalyzeResponse(
|
||||
impacts=[ImpactOut(**imp) for imp in result.impacts],
|
||||
total_detected=result.total_detected,
|
||||
score_zone=result.score_zone,
|
||||
score_groupement=result.score_groupement,
|
||||
score_total=result.score_total,
|
||||
dispersion_radius=result.dispersion_radius,
|
||||
center_x=result.center_x,
|
||||
center_y=result.center_y,
|
||||
annotated_image_base64=result.annotated_image_base64,
|
||||
error=result.error,
|
||||
warning=result.warning,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
port = int(os.getenv("PORT", 8000))
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False, workers=1)
|
||||
Reference in New Issue
Block a user