| from fastapi import FastAPI |
| from langgraph.graph import StateGraph |
| from typing import TypedDict, Annotated, List |
| from langgraph.graph.message import add_messages |
| from pydantic import BaseModel |
|
|
| |
# FastAPI application; every HTTP endpoint in this module is registered on it.
app = FastAPI(
    title="LangGraph Agent API",
)
|
|
class State(TypedDict):
    """Graph state passed between the workflow's nodes."""

    # Conversation history. ``add_messages`` is this key's reducer: a node's
    # returned messages are merged into the existing list instead of
    # replacing it. The reducer also coerces plain strings into message
    # objects, so the stored elements are not ``str`` -- hence the
    # unparameterized ``list`` rather than ``list[str]``.
    messages: Annotated[list, add_messages]
    # Name of the step expected to run next. Each node sets it, but the
    # graph's edges are static and do not read it for routing.
    current_step: str
|
|
class AgentInput(BaseModel):
    """Request body for ``POST /run-agent``: the messages that seed the graph."""

    messages: list[str]
|
|
def collect_info(state: State) -> dict:
    """Collect step: add an "Information collected" message to the history.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update holding just the *new* message and the next
        step name (``"process"``). ``State.messages`` uses the
        ``add_messages`` reducer, which merges returned messages into the
        existing history. Re-sending the full history is therefore
        unnecessary; it only worked because the reducer de-duplicates
        messages by id.
    """
    print("\n--> In collect_info")
    print(f"Messages before: {state['messages']}")

    new_messages = ["Information collected"]
    # Debug preview of the history once the reducer has appended the update.
    print(f"Messages after: {state['messages'] + new_messages}")

    return {
        "messages": new_messages,
        "current_step": "process",
    }
|
|
def process_info(state: State) -> dict:
    """Process step: add an "Information processed" message to the history.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update holding just the *new* message and the next
        step name (``"end"``). ``State.messages`` uses the ``add_messages``
        reducer, which merges returned messages into the existing history.
        Re-sending the full history is therefore unnecessary.
    """
    print("\n--> In process_info")
    print(f"Messages before: {state['messages']}")

    new_messages = ["Information processed"]
    # Debug preview of the history once the reducer has appended the update.
    print(f"Messages after: {state['messages'] + new_messages}")

    return {
        "messages": new_messages,
        "current_step": "end",
    }
|
|
| |
def _build_workflow() -> StateGraph:
    """Assemble the linear two-node graph: collect -> process."""
    graph = StateGraph(State)

    # Register the nodes under the names the edges below refer to.
    graph.add_node("collect", collect_info)
    graph.add_node("process", process_info)

    # "collect" always hands off to "process".
    graph.add_edge("collect", "process")

    # Execution starts at "collect" and ends after "process".
    graph.set_entry_point("collect")
    graph.set_finish_point("process")
    return graph


workflow = _build_workflow()

# Compiled, runnable graph invoked by the HTTP endpoint.
agent = workflow.compile()
|
|
|
|
@app.post("/run-agent")
async def run_agent(input_data: AgentInput):
    """
    Run the agent with the provided input messages.

    The compiled graph is awaited with ``ainvoke``. A synchronous ``invoke``
    inside this ``async def`` endpoint would block the event loop (and every
    other request) for the whole graph run.

    Returns:
        ``{"messages": [...]}``, the final ``messages`` value of the graph.
    """
    initial_state = State(messages=input_data.messages, current_step="collect")
    final_state = await agent.ainvoke(initial_state)
    # NOTE(review): ``add_messages`` coerces the input strings into message
    # objects, so this list is serialized as message objects rather than
    # plain strings. Confirm that this is the intended response shape.
    return {"messages": final_state["messages"]}
|
|
@app.get("/")
async def root():
    """
    Root endpoint reporting that the API is up and where to try it.

    The response carries a status message and one hint linking to the
    hosted interactive docs entry for the run-agent endpoint.
    """
    status = "LangGraph Agent API is running"
    docs_hint = "Navigate to https://jstoppa-langgraph-basic-example-api.hf.space/docs#/default/run_agent_run_agent_post to run the example"
    return {"message": status, "endpoints": [docs_hint]}
|
|
if __name__ == "__main__":
    # Imported here so the server package is only needed when this module
    # is executed directly rather than imported.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)