# triad_challenge_event_2024/bot.py
#
# 85 lines
# 2.5 KiB
# Python

import json, asyncio, websockets
import semantic_kernel as sk
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextEmbedding
# OpenAI credentials shared by the chat and embedding services below;
# presumably read from a local .env file by the Semantic Kernel helper
# -- TODO confirm.
api_key, org_id = sk.openai_settings_from_dot_env()
# Farewell text sent to the client when a chat session ends.
EXIT_CHAT = "\n\nExiting chat..."
async def init_kernel():
    """Assemble a kernel with chat, embeddings and text memory, then load it.

    Registers an OpenAI chat-completion service as ``"chat-gpt"`` and an
    OpenAI embedding service as ``"ada"``, attaches a ``VolatileMemoryStore``,
    imports ``TextMemoryPlugin`` under the name ``"text_memory"`` and finally
    awaits ``populate_memory`` on the result.

    Returns:
        The configured ``sk.Kernel`` instance.
    """
    support_kernel = sk.Kernel()

    chat_service = OpenAIChatCompletion(
        ai_model_id="gpt-3.5-turbo-1106",
        api_key=api_key,
        org_id=org_id,
    )
    support_kernel.add_chat_service("chat-gpt", chat_service)

    embedding_service = OpenAITextEmbedding(
        ai_model_id="text-embedding-ada-002",
        api_key=api_key,
        org_id=org_id,
    )
    support_kernel.add_text_embedding_generation_service("ada", embedding_service)

    # The memory store has to be registered before the memory plugin is
    # imported and before any information is saved into it.
    store = sk.memory.VolatileMemoryStore()
    support_kernel.register_memory_store(memory_store=store)

    memory_plugin = sk.core_plugins.TextMemoryPlugin()
    support_kernel.import_plugin(memory_plugin, "text_memory")

    await populate_memory(support_kernel)
    return support_kernel
# Single source of truth for the memory collection name: the context must
# recall from the same collection that populate_memory() writes to.
MEMORY_COLLECTION = "regulations"


def init_context(kernel):
    """Create a fresh conversation context for one chat session.

    Args:
        kernel: The kernel whose ``create_new_context()`` supplies the
            context object.

    Returns:
        The new context with an empty ``chat_history`` and the text-memory
        collection and relevance parameters set.
    """
    context = kernel.create_new_context()
    context["chat_history"] = ""
    context[sk.core_plugins.TextMemoryPlugin.COLLECTION_PARAM] = MEMORY_COLLECTION
    context[sk.core_plugins.TextMemoryPlugin.RELEVANCE_PARAM] = "0.8"
    return context


async def populate_memory(kernel):
    """Load the regulation texts from disk and save them into kernel memory.

    Reads ``plugins/mem/trial_regulations.json`` (a JSON object mapping an
    id to a text) and stores every entry in ``MEMORY_COLLECTION``.

    Args:
        kernel: The kernel whose ``memory`` receives the entries.

    Raises:
        OSError: If the JSON file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # The context manager closes the file even if parsing fails.
    with open("plugins/mem/trial_regulations.json", encoding="utf-8") as regulations_file:
        memories = json.load(regulations_file)

    for memory_id, text in memories.items():
        await kernel.memory.save_information(
            collection=MEMORY_COLLECTION,
            id=memory_id,
            text=text,
        )


async def chat(websocket, kernel, chat_function, context):
    """Run one prompt/answer round trip with the connected client.

    Sends the ``"User: "`` prompt, waits for a message, runs
    ``chat_function`` through ``kernel`` and sends the answer back. The
    exchange is appended to ``context["chat_history"]``.

    Args:
        websocket: The client connection; ``send`` and ``recv`` are awaited.
        kernel: The kernel used to run ``chat_function``.
        chat_function: The semantic function that produces the answer.
        context: The conversation context; mutated in place.

    Returns:
        ``True`` to keep chatting, ``False`` once the user typed ``exit``
        or the round was interrupted.

    Raises:
        websockets.ConnectionClosed: Propagated from ``send``/``recv`` when
            the client disconnects; the caller decides how to end the session.
    """
    try:
        # send() and recv() are coroutines: without ``await`` nothing is
        # transmitted and ``user_input`` would be a coroutine object.
        await websocket.send("User: ")
        user_input = await websocket.recv()
    except KeyboardInterrupt:
        await websocket.send(EXIT_CHAT)
        return False

    if user_input == "exit":
        await websocket.send(EXIT_CHAT)
        return False

    context["user_input"] = user_input
    answer = await kernel.run(chat_function, input_vars=context.variables)
    context["chat_history"] += f"\nUser: {user_input}\nSupport: {answer}\n"
    await websocket.send(f"Support: {answer}")
    return True


async def init_chat(websocket, path=None):
    """Handle one websocket connection: set up a kernel and chat until done.

    Args:
        websocket: The client connection handed over by the websockets server.
        path: Request path passed by websockets versions that call handlers
            with two arguments; unused. Optional so that versions calling
            the handler with the connection only work as well.
    """
    # init_kernel() already awaits populate_memory(), so memory is not
    # populated a second time here.
    kernel = await init_kernel()
    context = init_context(kernel)
    chat_function = kernel.import_semantic_plugin_from_directory(
        "plugins/", "SupportPlugin"
    )["Chat"]

    try:
        while await chat(websocket, kernel, chat_function, context):
            pass
    except websockets.ConnectionClosed:
        # The client went away mid-conversation; there is nobody left to
        # notify, so the session simply ends.
        pass
async def main():
    """Serve ``init_chat`` on ws://localhost:8000 until the process is stopped."""
    # The server is created inside the running loop and closed cleanly when
    # the ``async with`` block exits.
    async with websockets.serve(init_chat, "localhost", 8000):
        await asyncio.Future()  # never completes: keeps the server running


if __name__ == "__main__":
    asyncio.run(main())