32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
from langchain.chat_models import ChatOpenAI
|
|
from langchain.prompts import ChatPromptTemplate
|
|
from langchain.chains import LLMChain
|
|
from typing import Dict, Any
|
|
import asyncio
|
|
|
|
class SlotFiller:
    """Extract schema-described values ("slots") from a message with a chat LLM.

    The instance wires a ``ChatOpenAI`` model, a prompt template with
    ``{schema}`` and ``{message}`` placeholders, and an ``LLMChain`` that
    combines the two. ``extract_slots`` runs that chain and parses its
    reply as a JSON object.
    """

    def __init__(self):
        """Create the chat model, the extraction prompt and the chain."""
        self.llm = ChatOpenAI(temperature=0)
        # Built from adjacent literals so the prompt text does not depend on
        # how this method happens to be indented.
        self.prompt = ChatPromptTemplate.from_template(
            "\n"
            "You are a helpful assistant. Given a message and a schema, extract all known values.\n"
            "\n"
            "Only return a JSON object containing the extracted values and no extra text.\n"
            "\n"
            "Schema: {schema}\n"
            "Message: {message}\n"
        )
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt)

    async def extract_slots(self, schema: Dict[str, Any], message: str) -> Dict[str, Any]:
        """Ask the chain to extract the values described by *schema* from *message*.

        Args:
            schema: Description of the values to extract; substituted into the
                prompt's ``{schema}`` placeholder.
            message: Text to extract from; substituted into ``{message}``.

        Returns:
            The JSON object found in the chain's reply, as a dict. An empty
            dict is returned when the reply contains no parseable JSON
            object (this method is best-effort and does not raise on a
            malformed reply).
        """
        # Inside a coroutine the running loop is the one to use;
        # asyncio.get_event_loop() is deprecated for this purpose.
        loop = asyncio.get_running_loop()
        inputs = {"schema": schema, "message": message}
        # chain.run is a blocking call, so hand it to the loop's default
        # executor instead of stalling the event loop.
        raw_reply = await loop.run_in_executor(None, self.chain.run, inputs)
        return self._parse_json_object(raw_reply)

    @staticmethod
    def _parse_json_object(raw_reply: Any) -> Dict[str, Any]:
        """Return the JSON object contained in *raw_reply*, or ``{}`` if none."""
        import json

        try:
            parsed = json.loads(raw_reply)
        except (ValueError, TypeError):
            # ValueError covers json.JSONDecodeError; TypeError covers a
            # reply that is not str/bytes at all.
            parsed = None
            if isinstance(raw_reply, str):
                # The reply may wrap the object in Markdown fences or prose
                # despite the prompt; retry on the outermost {...} span.
                start = raw_reply.find("{")
                end = raw_reply.rfind("}")
                if 0 <= start < end:
                    try:
                        parsed = json.loads(raw_reply[start:end + 1])
                    except ValueError:
                        parsed = None
        # Valid JSON that is not an object (list, number, string, null)
        # would break the declared Dict return type, so treat it as "nothing
        # extracted".
        return parsed if isinstance(parsed, dict) else {}
|