nl2sql-api / backend /src /nl2sql /hf_engine.py
dvwn's picture
Update UI Frontend version 1.2.0 + Backend Integration
a5b3518
# Path: src/nl2sql/hf_engine.py
# This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
import os
from huggingface_hub import InferenceClient
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.language_models.llms import LLM
from typing import Any, List, Optional
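# Model registry: the models available for testing, each mapped to the interface get_llm() uses to call it ("text" = text-generation endpoint, "chat" = chat-completions API).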
MODEL_REGISTRY = {
    # Plain text-generation model; get_llm() routes it to HuggingFaceEndpoint.
    "defog/sqlcoder-7b-2": "text",
    # Chat-completion models; get_llm() wraps them in HFChatWrapper.
    "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai": "chat",
    "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai": "chat",
    "defog/llama-3-sqlcoder-8b:featherless-ai": "chat",
    # Currently disabled:
    # "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B:featherless-ai": "chat",
}

# Model used by get_llm() when model_id is omitted or falsy (registered above as "chat").
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai"
# Custom LangChain wrapper for HuggingFace Inference API
class HFChatWrapper(LLM):
    """
    LangChain ``LLM`` adapter around a HuggingFace ``InferenceClient`` chat-completions call.

    Each prompt is sent as a single ``user`` message and the assistant's reply
    text is returned. This lets chat-only models be used wherever LangChain
    expects a plain text-in / text-out LLM.

    Attributes:
        client: An ``InferenceClient`` (or any object exposing
            ``client.chat.completions.create``).
        model_id: Model identifier passed as ``model`` to the chat API.
        temperature: Sampling temperature sent with every request (default 0.0).
        max_tokens: Maximum number of tokens to generate per request (default 512).
    """

    client: Any
    model_id: str
    temperature: float = 0.0
    max_tokens: int = 512

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """
        Send ``prompt`` as one user message and return the assistant's reply.

        Args:
            prompt: Full prompt text.
            stop: Optional stop sequences, forwarded to the chat API.
            run_manager: LangChain callback manager. It is unused here, but
                accepting it lets LangChain pass it without a TypeError.
            **kwargs: Extra invocation kwargs forwarded by LangChain; ignored.

        Returns:
            The assistant message content, or "" if the API returned no content.

        Raises:
            ValueError: If the API response contains no choices.
        """
        completion = self.client.chat.completions.create(
            model=self.model_id,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stop=stop,
        )
        if not completion.choices:
            raise ValueError(f"Model '{self.model_id}' returned no choices.")
        # content can be None; the LLM interface promises a str.
        return completion.choices[0].message.content or ""

    @property
    def _identifying_params(self) -> dict:
        """Return the configuration that distinguishes one wrapper instance from another."""
        # Without this, the base class reports no params, so wrappers for
        # different models look identical to LangChain (e.g. for caching).
        return {
            "model_id": self.model_id,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "huggingface_inference_client"
def get_models() -> List[str]:
    """List every model ID registered in MODEL_REGISTRY, in registration order."""
    return [*MODEL_REGISTRY]
# Factory: build the LangChain LLM for a model (chat models via InferenceClient, text models via HuggingFaceEndpoint)
def get_llm(model_id: Optional[str] = DEFAULT_MODEL_ID) -> LLM:
    """
    Build a LangChain LLM for ``model_id``, choosing the backend from MODEL_REGISTRY.

    Routing:
        - "chat": an ``HFChatWrapper`` around an ``InferenceClient``.
        - "text": a ``HuggingFaceEndpoint`` using the text-generation task
          with greedy decoding.
    Models missing from MODEL_REGISTRY are treated as "chat", and a warning
    is printed.

    Args:
        model_id: HuggingFace model ID. A falsy value (None or "") falls back
            to DEFAULT_MODEL_ID.

    Returns:
        An ``HFChatWrapper`` or ``HuggingFaceEndpoint`` instance.

    Raises:
        ValueError: If the HF_TOKEN environment variable is unset or empty,
            or if the registry maps the model to an unknown type.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HuggingFace API token not found!")

    active_model = model_id or DEFAULT_MODEL_ID
    model_type = MODEL_REGISTRY.get(active_model)
    if model_type is None:
        print(f"Warning: Model '{active_model}' not found in MODEL_REGISTRY. Defaulting to 'chat' type.")
        model_type = "chat"
    print(f"Initializing HuggingFace InferenceClient with model: {active_model}")

    if model_type == "chat":
        client = InferenceClient(api_key=hf_token)
        return HFChatWrapper(client=client, model_id=active_model)

    if model_type == "text":
        # Route to the standard text-generation API.
        return HuggingFaceEndpoint(
            repo_id=active_model,
            task="text-generation",
            max_new_tokens=512,
            # Greedy decoding (do_sample=False) ignores temperature anyway.
            # Leave it unset: TGI-backed endpoints reject temperature=0.0
            # ("must be strictly positive").
            temperature=None,
            huggingfacehub_api_token=hf_token,
            do_sample=False,
            return_full_text=False,
        )

    raise ValueError(f"Unknown model type: {model_type}")
if __name__ == "__main__":
    # Manual smoke test: load .env, build the default model, and send one prompt.
    # Exits with status 1 if either step fails, so scripts/CI can detect it.
    from dotenv import load_dotenv

    load_dotenv()

    try:
        test_llm = get_llm()
    except Exception as e:
        print(f"Error during LLM initialization: {e}")
        raise SystemExit(1) from e

    print("Model loaded successfully! Running a quick ping...")
    try:
        response = test_llm.invoke("write a single SQL statement to count all rows in a table name 'Employee'.")
    except Exception as e:
        # Report inference failures separately from initialization failures.
        print(f"Error during LLM inference: {e}")
        raise SystemExit(1) from e

    print(f"\nResponse:\n{response}")