import asyncio
import json

import websockets

import semantic_kernel as sk
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextEmbedding

api_key, org_id = sk.openai_settings_from_dot_env()
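# The call above reads the OpenAI credentials from a local .env file,
# e.g. (placeholder values):
#
#   OPENAI_API_KEY="sk-..."
#   OPENAI_ORG_ID="org-..."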

EXIT_CHAT = "\n\nExiting chat..."

async def init_kernel():
    kernel = sk.Kernel()

    # Chat completion service used by the SupportPlugin's Chat function.
    kernel.add_chat_service(
        "chat-gpt",
        OpenAIChatCompletion(
            ai_model_id="gpt-3.5-turbo-1106",
            api_key=api_key,
            org_id=org_id,
        ),
    )

    # Embedding service backing the semantic memory store.
    kernel.add_text_embedding_generation_service(
        "ada",
        OpenAITextEmbedding(
            ai_model_id="text-embedding-ada-002",
            api_key=api_key,
            org_id=org_id,
        ),
    )

    # Volatile (in-process, non-persistent) memory store plus the plugin that
    # exposes save/recall over it, then seed it with the regulation snippets.
    kernel.register_memory_store(memory_store=sk.memory.VolatileMemoryStore())
    kernel.import_plugin(sk.core_plugins.TextMemoryPlugin(), "text_memory")
    await populate_memory(kernel)

    return kernel

def init_context(kernel):
    context = kernel.create_new_context()
    context["chat_history"] = ""
    # Point the TextMemoryPlugin at the collection filled by populate_memory()
    # and only recall memories above the relevance threshold.
    context[sk.core_plugins.TextMemoryPlugin.COLLECTION_PARAM] = "regulations"
    context[sk.core_plugins.TextMemoryPlugin.RELEVANCE_PARAM] = "0.8"

    return context

async def populate_memory(kernel):
    with open("plugins/mem/trial_regulations.json") as f:
        memories = json.load(f)

    # Embed every regulation snippet into the "regulations" collection.
    for mem_id, text in memories.items():
        await kernel.memory.save_information(
            collection="regulations",
            id=mem_id,
            text=text,
        )

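# populate_memory() expects trial_regulations.json to be a flat mapping of
# memory ids to the text to embed, for example (illustrative entries only):
#
#   {
#       "reg-001": "Text of the first regulation snippet ...",
#       "reg-002": "Text of the second regulation snippet ..."
#   }
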
async def chat(websocket, kernel, chat_function, context):
    try:
        # Prompt the client and wait for its next message.
        await websocket.send("User: ")
        user_input = await websocket.recv()
        context["user_input"] = user_input
    except KeyboardInterrupt:
        await websocket.send(EXIT_CHAT)
        return False

    if user_input == "exit":
        await websocket.send(EXIT_CHAT)
        return False

    # Run the Chat semantic function with the accumulated context variables,
    # then append the exchange to the rolling chat history.
    answer = await kernel.run(chat_function, input_vars=context.variables)
    context["chat_history"] += f"\nUser: {user_input}\nSupport: {answer}\n"
    await websocket.send(f"Support: {answer}")
    return True

async def init_chat(websocket, path):
    # `path` is required by the legacy websockets handler signature.
    kernel = await init_kernel()
    context = init_context(kernel)
    chat_function = kernel.import_semantic_plugin_from_directory("plugins/", "SupportPlugin")["Chat"]

    # Keep exchanging messages until chat() signals that the session is over.
    chatting = True
    while chatting:
        chatting = await chat(websocket, kernel, chat_function, context)

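# init_chat() loads the Chat function from disk; the assumed plugin layout
# (standard semantic-kernel convention, adjust to the actual plugin files) is:
#
#   plugins/
#       SupportPlugin/
#           Chat/
#               skprompt.txt   # template referencing {{$chat_history}} and {{$user_input}}
#               config.json    # completion settings for the Chat function
#       mem/
#           trial_regulations.json
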
# Serve the chat handler on ws://localhost:8000 until the process is stopped.
server = websockets.serve(init_chat, "localhost", 8000)
asyncio.get_event_loop().run_until_complete(server)
asyncio.get_event_loop().run_forever()
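
# Illustrative client sketch for manual testing (assumes the server above is
# already running on localhost:8000; the question text is a placeholder):
#
#   import asyncio, websockets
#
#   async def demo_client():
#       async with websockets.connect("ws://localhost:8000") as ws:
#           print(await ws.recv())              # "User: " prompt from the server
#           await ws.send("Which regulations apply to my trial?")
#           print(await ws.recv())              # "Support: ..." answer
#
#   asyncio.run(demo_client())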