from flask import Flask, render_template, request, jsonify, send_file
import pandas as pd
from pathlib import Path
import sys
import matplotlib
import asyncio
import threading
import uuid
from io import BytesIO
# Resolve the directory that contains this file, then append the directory
# three levels above it to sys.path so it is on the import path by the time
# the `tablemage` imports that follow are executed.
# NOTE(review): presumably that directory is the repository root holding the
# in-tree `tablemage` package — confirm against the project layout.
ui_path = Path(__file__).parent.resolve()
path_to_add = str(ui_path.parent.parent.parent)
sys.path.append(path_to_add)
from tablemage.agents.api import ChatDA
from tablemage.agents._src.io.canvas import (
CanvasCode,
CanvasFigure,
CanvasTable,
CanvasThought,
)
# Module-level state shared by the route handlers in this file.
# `agent` stays None until the /upload route successfully builds a ChatDA.
agent: ChatDA | None = None
# Extra keyword arguments forwarded to ChatDA(...) by the /upload route;
# replaced by ChatDA_UserInterface.__init__.
chatda_kwargs: dict = {}
# Maps a task id (uuid4 string) to {"status": ..., "response": ...} for chat
# requests that run in background threads.
chat_tasks: dict = {}
def chat(msg: str) -> str:
    """Send one user message to the dataset agent and return its reply.

    Parameters
    ----------
    msg : str
        The natural-language query to pass to the agent.

    Returns
    -------
    str
        The result of running the agent's ``achat`` coroutine to completion,
        or a fixed notice when the module-level ``agent`` is still ``None``
        (no dataset has been uploaded).
    """
    if agent is not None:
        # achat is a coroutine; asyncio.run drives it on a fresh event loop.
        return asyncio.run(agent.achat(msg))
    return "No dataset uploaded. Please upload a dataset first."
def get_analysis():
    """Return the analysis items held by the current agent's canvas queue.

    Reads the module-level ``agent`` without checking it for ``None``; the
    callers are responsible for making sure a dataset has been uploaded.
    """
    canvas_queue = agent._canvas_queue
    return canvas_queue.get_analysis()
# Single Flask application object for this module: the route decorators in
# this file register on it, and ChatDA_UserInterface stores it as
# `self.flask_app` and runs it.
flask_app = Flask(__name__)
@flask_app.route("/")
def index():
    """Render the landing page from the ``index.html`` template."""
    page = render_template("index.html")
    return page
@flask_app.route("/upload", methods=["POST"])
def upload_dataset():
    """Accept a CSV upload and build the module-level agent from it.

    Reads the ``file`` part of the request and an optional ``test_size`` form
    field (default ``0.2``), parses the file with ``pandas.read_csv`` and
    stores a new ``ChatDA`` instance in the module-level ``agent``.

    Returns
    -------
    tuple
        ``(json, 400)`` when the file part is missing, the filename is empty,
        or ``test_size`` is not a float within ``[0.0, 1.0]``;
        ``(json, 500)`` carrying the exception text when reading the CSV or
        constructing the agent fails; ``(json, 200)`` on success.
    """
    global agent
    global chatda_kwargs

    uploads = request.files
    if "file" not in uploads:
        return jsonify({"error": "No file part in the request"}), 400

    upload = uploads["file"]
    if upload.filename == "":
        return jsonify({"error": "No selected file"}), 400

    raw_test_size = request.form.get("test_size", 0.2)
    try:
        test_size = float(raw_test_size)
        # Written as a negated chained comparison so that NaN, which fails
        # every comparison, is rejected as well.
        in_range = 0.0 <= test_size <= 1.0
        if not in_range:
            raise ValueError("Test size must be between 0.0 and 1.0.")
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    try:
        frame = pd.read_csv(upload)
        # "Unnamed: 0" is the label pandas assigns to a first column with an
        # empty header; drop it when it leads the frame.
        leading_column = frame.columns[0]
        if leading_column == "Unnamed: 0":
            frame = frame.drop(columns="Unnamed: 0")
        agent = ChatDA(frame, test_size=test_size, **chatda_kwargs)
        return jsonify({"message": "Dataset uploaded successfully"}), 200
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@flask_app.route("/chat", methods=["POST"])
def chat_route():
    """Start processing a chat message in the background and return a task id.

    Expects a JSON object body with a non-empty ``message`` entry. The message
    is handed to ``chat`` on a daemon thread; the outcome is recorded in
    ``chat_tasks`` under a freshly generated uuid4 task id.

    Returns
    -------
    flask.Response | tuple
        ``{"task_id": ...}`` once the worker thread has been started, or
        ``({"error": "No message provided"}, 400)`` when the body is missing,
        is not a JSON object, or has no usable ``message``.
    """
    # silent=True yields None instead of raising for a missing or malformed
    # JSON body. The isinstance check additionally rejects valid JSON that is
    # not an object (null, list, string, number); those values have no .get()
    # and would otherwise surface as an unhandled AttributeError.
    payload = request.get_json(silent=True)
    user_message = payload.get("message") if isinstance(payload, dict) else None
    if not user_message:
        return jsonify({"error": "No message provided"}), 400

    task_id = str(uuid.uuid4())
    record = {"status": "running", "response": None}
    chat_tasks[task_id] = record

    def run_chat():
        # "response" is written before "status" so that anyone who observes a
        # non-running status also sees the final response.
        try:
            record["response"] = chat(user_message)
            record["status"] = "done"
        except Exception as e:
            record["response"] = f"Error: {e}"
            record["status"] = "error"

    thread = threading.Thread(target=run_chat, daemon=True)
    thread.start()
    return jsonify({"task_id": task_id})
@flask_app.route("/chat/status/<task_id>", methods=["GET"])
def chat_status(task_id):
    """Report the state of a chat task and hand over its result once finished.

    Parameters
    ----------
    task_id : str
        The id previously returned by the ``/chat`` route.

    Returns
    -------
    flask.Response | tuple
        ``({"error": "Task not found"}, 404)`` for an unknown id,
        ``{"status": "running"}`` while the worker is still busy, otherwise
        ``{"status": ..., "response": ...}``. A finished task is removed from
        ``chat_tasks`` when it is reported, so its result is delivered once.
    """
    task = chat_tasks.get(task_id)
    if task is None:
        return jsonify({"error": "Task not found"}), 404

    # Read the status a single time so the branch taken and the value
    # reported cannot disagree.
    status = task["status"]
    if status == "running":
        return jsonify({"status": "running"})

    # pop(..., None) rather than del: the app is served with threaded=True, so
    # two polls for the same finished task can both get this far; the second
    # one must not fail with KeyError.
    chat_tasks.pop(task_id, None)
    return jsonify({"status": status, "response": task["response"]})
@flask_app.route("/analysis", methods=["GET"])
def get_analysis_history():
    """Return the full analysis history as a JSON list.

    Each canvas item is converted to a dict whose ``file_type`` is one of
    ``"figure"``, ``"table"``, ``"thought"`` or ``"code"``. Figures carry
    their file name and path, tables carry the pickled DataFrame rendered as
    HTML, thoughts and code carry their text.

    Returns
    -------
    flask.Response | tuple
        The JSON list on success; ``(json, 400)`` when no dataset has been
        uploaded; ``(json, 500)`` when anything fails while building the
        list, including an item of an unrecognised type.
    """
    if agent is None:
        return (
            jsonify({"error": "No dataset uploaded. Please upload a dataset first."}),
            400,
        )
    try:
        payload = []
        for entry in get_analysis():
            if isinstance(entry, CanvasFigure):
                figure_path = Path(entry.path)
                record = {
                    "file_name": figure_path.name,
                    "file_type": "figure",
                    "file_path": str(figure_path),
                }
            elif isinstance(entry, CanvasTable):
                table_path = Path(entry.path)
                # NOTE(review): read_pickle unpickles the file; presumably
                # these files are written by the agent itself — confirm the
                # path can never be user-controlled.
                table = pd.read_pickle(table_path)
                record = {
                    "file_name": table_path.name,
                    "file_type": "table",
                    "content": table.to_html(classes="table", index=True),
                }
            elif isinstance(entry, CanvasThought):
                record = {"file_type": "thought", "content": entry._thought}
            elif isinstance(entry, CanvasCode):
                record = {"file_type": "code", "content": entry._code}
            else:
                raise ValueError(f"Unknown item type: {type(entry)}")
            payload.append(record)
        return jsonify(payload)
    except Exception as e:
        flask_app.logger.error(f"Error retrieving analysis history: {str(e)}")
        return jsonify({"error": "Failed to retrieve analysis history"}), 500
@flask_app.route("/analysis/since/<int:index>", methods=["GET"])
def get_analysis_since(index):
    """Return the analysis items added at or after position ``index``.

    Used for polling while the agent is working. Every path answers with the
    same JSON envelope, ``{"items": [...], "total": <int>}``, so a client can
    always read both keys.

    Parameters
    ----------
    index : int
        Position in the analysis list from which items are returned.

    Returns
    -------
    flask.Response
        ``items`` holds the serialised new entries and ``total`` the current
        length of the analysis list. With no dataset uploaded the envelope is
        empty with ``total`` 0; if serialisation fails it is empty with
        ``total`` echoing ``index``. Items of an unrecognised type are
        skipped.
    """
    if agent is None:
        # Previously a bare list was returned here, unlike the success and
        # error paths; use the same envelope so clients need no special case.
        return jsonify({"items": [], "total": 0})
    try:
        analysis_items = get_analysis()
        items = []
        for entry in analysis_items[index:]:
            if isinstance(entry, CanvasFigure):
                figure_path = Path(entry.path)
                items.append(
                    {
                        "file_name": figure_path.name,
                        "file_type": "figure",
                        "file_path": str(figure_path),
                    }
                )
            elif isinstance(entry, CanvasTable):
                table_path = Path(entry.path)
                # NOTE(review): read_pickle unpickles the file; presumably
                # these files are written by the agent itself — confirm the
                # path can never be user-controlled.
                table = pd.read_pickle(table_path)
                items.append(
                    {
                        "file_name": table_path.name,
                        "file_type": "table",
                        "content": table.to_html(classes="table", index=True),
                    }
                )
            elif isinstance(entry, CanvasThought):
                items.append({"file_type": "thought", "content": entry._thought})
            elif isinstance(entry, CanvasCode):
                items.append({"file_type": "code", "content": entry._code})
        return jsonify({"items": items, "total": len(analysis_items)})
    except Exception as e:
        flask_app.logger.error(f"Error retrieving incremental analysis: {str(e)}")
        # Echo the caller's index as the total so its position is unchanged.
        return jsonify({"items": [], "total": index})
@flask_app.route("/analysis/file/<filename>", methods=["GET"])
def serve_file(filename):
    """Send the figure file from the analysis list whose name is ``filename``.

    Only paths belonging to ``CanvasFigure`` items in the current analysis
    list are considered; the first one with a matching file name that exists
    on disk is sent.

    Returns
    -------
    flask.Response | tuple
        The file on a match; ``(json, 400)`` when no dataset has been
        uploaded; ``(json, 404)`` when no existing figure has that name.
    """
    if agent is None:
        return (
            jsonify({"error": "No dataset uploaded. Please upload a dataset first."}),
            400,
        )

    for entry in get_analysis():
        if not isinstance(entry, CanvasFigure):
            continue
        if entry._path.name != filename:
            continue
        figure_path = entry._path
        if figure_path.exists():
            return send_file(figure_path)

    return jsonify({"error": f"File '{filename}' not found."}), 404
@flask_app.route("/download_transcript", methods=["GET"])
def download_transcript():
    """Send the agent's transcript as a ``transcript.txt`` attachment.

    Returns
    -------
    flask.Response | tuple
        The UTF-8 encoded transcript as a download; ``(json, 400)`` when no
        dataset has been uploaded; ``(json, 500)`` when producing or sending
        the transcript fails.
    """
    if agent is None:
        return (
            jsonify({"error": "No dataset uploaded. Please upload a dataset first."}),
            400,
        )
    try:
        encoded = agent.get_transcript().encode("utf-8")
        stream = BytesIO(encoded)
        stream.seek(0)
        attachment_options = {
            "as_attachment": True,
            "download_name": "transcript.txt",
            "mimetype": "text/plain; charset=utf-8",
        }
        return send_file(stream, **attachment_options)
    except Exception as e:
        flask_app.logger.error(f"Error downloading transcript: {str(e)}")
        return jsonify({"error": "Failed to generate transcript"}), 500
[docs]
class ChatDA_UserInterface:
[docs]
def __init__(
self,
split_seed: int | None = None,
system_prompt: str | None = None,
memory_size: int | None = None,
tool_rag: bool | None = None,
tool_rag_top_k: int | None = None,
python_only: bool | None = None,
tools_only: bool | None = None,
multimodal: bool | None = None,
):
"""Makes a user interface for the ChatDA agent.
Parameters
----------
split_seed : int | None
If None, default seed is used.
system_prompt : str | None
If None, default system prompt is used.
memory_size : int | None
If None, default memory size is used.
The size of the buffer.
tool_rag : bool | None
If None, default tool RAG flag is used. \
If True, tool RAG is used. If False, tool RAG is not used,
and all tools are provided to the agent for each query.
tool_rag_top_k : int | None
If None, default tool RAG top k is used.
The number of tools to provide to the agent for each query.
python_only : bool | None
If None, default Python-only flag is used. \
If True, only Python environment is provided. If False, all tools are used.
tools_only : bool | None
If None, default tools-only flag is used. \
If True, only tools are used. If False, all tools are used.
multimodal : bool | None
If None, default multimodal flag is used. \
If True, multimodal model is used for image interpretation.
"""
matplotlib.use("Agg")
global chatda_kwargs
chatda_kwargs = {
"split_seed": split_seed,
"system_prompt": system_prompt,
"memory_size": memory_size,
"tool_rag": tool_rag,
"tool_rag_top_k": tool_rag_top_k,
"python_only": python_only,
"tools_only": tools_only,
"multimodal": multimodal,
}
chatda_kwargs = {k: v for k, v in chatda_kwargs.items() if v is not None}
self.flask_app = flask_app
[docs]
def run(self, host: str = "0.0.0.0", port: str = "5050", debug: bool = False):
"""Runs the Flask app for the ChatDA agent user interface.
Parameters
----------
host : str
The host IP address to run the app on.
port : str
The port number to run the app on.
debug : bool
If True, the app runs in debug mode.
"""
self.flask_app.run(host=host, debug=debug, port=port, threaded=True)