"""WebSocket support-chat server backed by Semantic Kernel.

Each client connection gets its own kernel, an in-memory store filled from
``plugins/mem/trial_regulations.json``, and a chat context. The client's
messages are answered by the ``Chat`` function of ``SupportPlugin`` until the
client sends ``exit`` or disconnects.
"""
import asyncio
import json

import semantic_kernel as sk
import websockets
from semantic_kernel.connectors.ai.open_ai import (
    OpenAIChatCompletion,
    OpenAITextEmbedding,
)

api_key, org_id = sk.openai_settings_from_dot_env()

EXIT_CHAT = "\n\nExiting chat..."

# Single source of truth for the memory collection name: the collection that
# populate_memory() writes must be the one the chat context recalls from.
MEMORY_COLLECTION = "regulations"
MEMORIES_PATH = "plugins/mem/trial_regulations.json"


async def init_kernel():
    """Build a kernel with chat, embedding and text-memory support.

    Registers the OpenAI chat and embedding services, attaches a volatile
    (in-process) memory store, imports the text-memory plugin and loads the
    regulation texts into memory.

    Returns:
        The configured ``sk.Kernel``.
    """
    kernel = sk.Kernel()
    kernel.add_chat_service(
        "chat-gpt",
        OpenAIChatCompletion(
            ai_model_id="gpt-3.5-turbo-1106", api_key=api_key, org_id=org_id
        ),
    )
    kernel.add_text_embedding_generation_service(
        "ada",
        OpenAITextEmbedding(
            ai_model_id="text-embedding-ada-002", api_key=api_key, org_id=org_id
        ),
    )
    kernel.register_memory_store(memory_store=sk.memory.VolatileMemoryStore())
    kernel.import_plugin(sk.core_plugins.TextMemoryPlugin(), "text_memory")
    await populate_memory(kernel)
    return kernel


def init_context(kernel):
    """Create a fresh chat context for one conversation.

    Args:
        kernel: Kernel used to create the context. It is passed explicitly
            because no module-level kernel exists.

    Returns:
        A new context with an empty ``chat_history`` and the text-memory
        collection / relevance parameters set.
    """
    context = kernel.create_new_context()
    context["chat_history"] = ""
    context[sk.core_plugins.TextMemoryPlugin.COLLECTION_PARAM] = MEMORY_COLLECTION
    context[sk.core_plugins.TextMemoryPlugin.RELEVANCE_PARAM] = "0.8"
    return context


async def populate_memory(kernel):
    """Save every entry of the regulations JSON file into kernel memory.

    The file is read as a JSON object; each key is used as the memory id and
    each value as the memory text.

    Args:
        kernel: Kernel whose memory receives the entries.
    """
    with open(MEMORIES_PATH, encoding="utf-8") as memories_file:
        memories = json.load(memories_file)
    for memory_id, text in memories.items():
        await kernel.memory.save_information(
            collection=MEMORY_COLLECTION, id=memory_id, text=text
        )


async def chat(websocket, kernel, chat_function, context):
    """Run one prompt/answer exchange with the client.

    Args:
        websocket: Connection to the client.
        kernel: Kernel used to run ``chat_function``.
        chat_function: Semantic function that produces the answer.
        context: Conversation context; ``user_input`` and ``chat_history``
            are updated in place.

    Returns:
        ``True`` if the conversation should continue, ``False`` once the
        user typed ``exit`` or the exchange was interrupted.
    """
    try:
        # send() and recv() are coroutines and must be awaited.
        await websocket.send("User: ")
        user_input = await websocket.recv()
        context["user_input"] = user_input
    except KeyboardInterrupt:
        await websocket.send(EXIT_CHAT)
        return False

    if user_input == "exit":
        await websocket.send(EXIT_CHAT)
        return False

    answer = await kernel.run(chat_function, input_vars=context.variables)
    context["chat_history"] += f"\nUser: {user_input}\nSupport: {answer}\n"
    await websocket.send(f"Support: {answer}")
    return True


async def init_chat(websocket, path=None):
    """Handle one client connection: set up a session and chat until done.

    Args:
        websocket: Connection to the client.
        path: Unused. Optional so the handler works both with websockets
            versions that pass the request path as a second argument and
            with versions that pass only the connection.
    """
    # init_kernel() already loads the regulation memories.
    kernel = await init_kernel()
    context = init_context(kernel)
    chat_function = kernel.import_semantic_plugin_from_directory(
        "plugins/", "SupportPlugin"
    )["Chat"]
    try:
        while await chat(websocket, kernel, chat_function, context):
            pass
    except websockets.ConnectionClosed:
        # The client went away without typing "exit"; end the session quietly.
        pass


async def main():
    """Serve chat sessions on localhost:8000 until the process is stopped."""
    async with websockets.serve(init_chat, "localhost", 8000):
        await asyncio.Future()  # never completes: keeps the server running


# The server starts when the file is executed or imported, as before.
asyncio.run(main())