import math
import os
import sqlite3
from collections import Counter
from contextlib import closing
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd
from adaptive_sdk.external import Response, RewardServer, ServerInfo, ValidatedRequest
from pydantic import BaseModel, Field
class SQLMetadata(BaseModel):
    """Request metadata needed to score one SQL completion.

    Carries the expected result rows and the database the query must run
    against.
    """

    # Expected result rows; each dict maps a column name to that row's value.
    ground_truth_results: List[Dict[str, Any]]
    # Database file, relative to (joined onto) the server's db_base_path.
    db_path: str
class SQLRewardServer(RewardServer[SQLMetadata]):
    """Reward server that scores a model-generated SQL query by executing it.

    The content of the last turn of each request is treated as a SQL query.
    It is run against a SQLite database opened read-only. Its result rows are
    compared, ignoring row order, with the ground-truth rows in the request
    metadata.

    Rewards:
        1.0  -- the query ran and its rows match the ground truth
        0.0  -- the query ran but its rows differ
        -1.0 -- the query does not start with SELECT, or it failed to run
    """

    def __init__(self, db_base_path: str, port=8000, blocking=True, **kwargs):
        """Create the server.

        Args:
            db_base_path: Directory that each request's ``db_path`` is joined onto.
            port: Port passed through to the base ``RewardServer``.
            blocking: Passed through to the base ``RewardServer``.
            **kwargs: Extra options forwarded to the base ``RewardServer``.
        """
        # Assigned before super().__init__: presumably the base class may start
        # serving (and block) inside its constructor -- TODO confirm.
        self.db_base_path = db_base_path
        super().__init__(port, SQLMetadata, blocking, **kwargs)

    async def score(self, request: ValidatedRequest[SQLMetadata]) -> Response:
        """Execute the submitted query and score its result rows.

        Args:
            request: Validated request; ``turns[-1].content`` holds the SQL.

        Returns:
            A ``Response`` whose reward and ``status`` metadata are one of:
            ``invalid_query`` (-1.0), ``error`` (-1.0, plus ``message``),
            ``wrong_result`` (0.0) or ``success`` (1.0).
        """
        # Models often emit leading whitespace or newlines before the statement.
        sql_query = request.turns[-1].content.strip()
        if not sql_query.upper().startswith("SELECT"):
            return Response(reward=-1.0, metadata={"status": "invalid_query"})
        try:
            db_file = os.path.join(self.db_base_path, request.metadata.db_path)
            # closing(): sqlite3.Connection's own context manager only ends the
            # transaction; it does not close the connection.
            with closing(self._connect_read_only(db_file)) as conn:
                df_actual = pd.read_sql_query(sql_query, conn)
            actual_results = df_actual.to_dict(orient="records")
            match = self._results_match(
                actual_results, request.metadata.ground_truth_results
            )
        except Exception as e:
            # Boundary handler: any failure of the untrusted query scores -1.
            return Response(
                reward=-1.0, metadata={"status": "error", "message": str(e)}
            )
        return Response(
            reward=float(match),
            metadata={"status": "success" if match else "wrong_result"},
        )

    @staticmethod
    def _connect_read_only(db_file: str) -> sqlite3.Connection:
        """Open *db_file* read-only.

        A plain ``sqlite3.connect(path)`` silently creates an empty database
        when *path* does not exist, and it lets the query write. ``mode=ro``
        makes a missing file raise ``sqlite3.OperationalError`` and rejects
        writes.
        """
        # as_uri() percent-encodes characters such as '?', '#' and '%' that
        # would otherwise be misparsed as URI syntax.
        uri = Path(db_file).resolve().as_uri() + "?mode=ro"
        return sqlite3.connect(uri, uri=True)

    @staticmethod
    def _row_key(row):
        """Return a hashable, column-order-independent key for one result row.

        NaN cells are mapped to None. pandas may represent SQL NULL as NaN in
        numeric columns, and NaN never compares equal to None or to itself.
        """
        return frozenset(
            (col, None if isinstance(val, float) and math.isnan(val) else val)
            for col, val in row.items()
        )

    def _results_match(self, actual, expected):
        """Return True if both row lists contain the same rows, in any order.

        Duplicate rows count: each distinct row must occur the same number of
        times in both lists (multiset equality).
        """
        if len(actual) != len(expected):
            return False
        return Counter(map(self._row_key, actual)) == Counter(
            map(self._row_key, expected)
        )

    async def info(self) -> ServerInfo:
        """Describe this reward server."""
        return ServerInfo(
            version="1.0", name="SQL Evaluator", description="Evaluates SQL queries"
        )
if __name__ == "__main__":
    # Placeholder directory: point it at the folder holding the SQLite files
    # that requests reference through metadata.db_path.
    # blocking=True is the constructor default, spelled out for readability.
    server = SQLRewardServer(
        db_base_path="/path/to/dbs/",
        port=50056,
        blocking=True,
    )