"""
StackNet API Client

Handles all communication with the StackNet task network.
SSE parsing and progress tracking are handled internally.
"""
|
|
| import json |
| import tempfile |
| import os |
| from typing import AsyncGenerator, Optional, Any, Callable |
| from dataclasses import dataclass |
| from enum import Enum |
|
|
| import httpx |
|
|
| from ..config import config |
|
|
|
|
class MediaAction(str, Enum):
    """Media orchestration actions understood by the media task type.

    Each member's value is the exact string sent as the ``action`` field
    of a media task payload. Because the enum mixes in ``str``, members
    compare equal to those raw strings.
    """

    # Declaration order is preserved on iteration; keep it stable.
    GENERATE_MUSIC = "generate_music"
    CREATE_COVER = "create_cover"
    EXTRACT_STEMS = "extract_stems"
    ANALYZE_VISUAL = "analyze_visual"
    DESCRIBE_VIDEO = "describe_video"
    CREATE_COMPOSITE = "create_composite"
|
|
|
|
@dataclass()
class TaskProgress:
    """One progress notification for a task that is still running."""

    # NOTE(review): presumably a 0.0-1.0 completion fraction — confirm with producers.
    progress: float
    # Status label reported alongside the progress value.
    status: str
    # Human-readable description of the current step.
    message: str
|
|
|
|
@dataclass()
class TaskResult:
    """Final outcome of a task: success flag, output payload, optional error text."""

    # True only when the task produced a result without reporting an error.
    success: bool
    # Task output payload; StackNetClient passes an empty dict on failure.
    data: dict
    # Failure description; left as None for successful results.
    error: Optional[str] = None
|
|
|
|
class StackNetClient:
    """
    Client for the StackNet task network API.

    Task submission methods stream the server's SSE response, forward
    progress events to an optional callback and reduce the stream to a
    single ``TaskResult``. During submission, non-200 responses, timeouts
    and network errors are returned as failed results rather than raised.
    Files fetched with ``download_file`` are stored in a per-client temp
    directory that ``cleanup`` removes.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: float = 300.0
    ):
        """
        Args:
            base_url: API root URL; defaults to ``config.stacknet_url``.
            api_key: Optional key sent as ``Authorization: Bearer <key>``;
                a value that already starts with "Bearer " is sent unchanged.
            timeout: httpx timeout in seconds for task submission requests.
        """
        self.base_url = base_url or config.stacknet_url
        self.api_key = api_key
        self.timeout = timeout
        # Created eagerly; download_file() writes here and cleanup() removes it.
        self._temp_dir = tempfile.mkdtemp(prefix="stacknet_")

    def _build_headers(self) -> dict:
        """Return JSON request headers, plus Bearer auth when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            # Accept keys given either bare or already prefixed with "Bearer ".
            if self.api_key.startswith("Bearer "):
                headers["Authorization"] = self.api_key
            else:
                headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    async def _post_task(
        self,
        payload: dict,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """POST ``payload`` to ``<base_url>/tasks`` and consume the SSE reply.

        Shared transport for every ``submit_*`` method. A non-200 status,
        a timeout or a network error becomes a failed ``TaskResult``
        instead of propagating to the caller.
        """
        # Tolerate a configured base URL that ends with a slash.
        url = f"{self.base_url.rstrip('/')}/tasks"

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            try:
                async with client.stream(
                    "POST",
                    url,
                    json=payload,
                    headers=self._build_headers()
                ) as response:
                    if response.status_code != 200:
                        error_text = await response.aread()
                        # errors="replace": a non-UTF-8 error body must not turn
                        # the HTTP failure into an uncaught UnicodeDecodeError.
                        snippet = error_text.decode("utf-8", errors="replace")[:200]
                        return TaskResult(
                            success=False,
                            data={},
                            error=f"API request failed ({response.status_code}): {snippet}"
                        )

                    return await self._process_sse_stream(response, on_progress)

            # TimeoutException subclasses RequestError, so it is caught first
            # to keep its more specific message.
            except httpx.TimeoutException:
                return TaskResult(
                    success=False,
                    data={},
                    error="Request timed out. The operation took too long."
                )
            except httpx.RequestError as e:
                return TaskResult(
                    success=False,
                    data={},
                    error=f"Network error: {str(e)}"
                )

    async def submit_tool_task(
        self,
        tool_name: str,
        parameters: dict,
        server_name: str = "geoff",
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        Submit an MCP tool task and wait for completion.

        Args:
            tool_name: The tool to invoke (e.g., generate_image_5)
            parameters: Tool parameters
            server_name: MCP server name (default: geoff)
            on_progress: Callback for progress updates (progress: 0-1, message: str)

        Returns:
            TaskResult with success status and output data
        """
        payload = {
            "type": "mcp-tool",
            "serverName": server_name,
            "toolName": tool_name,
            "stream": True,
            "parameters": parameters
        }
        return await self._post_task(payload, on_progress)

    async def submit_media_task(
        self,
        action: MediaAction,
        prompt: Optional[str] = None,
        media_url: Optional[str] = None,
        audio_url: Optional[str] = None,
        video_url: Optional[str] = None,
        options: Optional[dict] = None,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """
        Submit a media orchestration task and wait for completion.

        Args:
            action: The media action to perform; a ``MediaAction`` or its
                string value (an unknown string raises ``ValueError``)
            prompt: Text prompt for generation
            media_url: URL for image input
            audio_url: URL for audio input
            video_url: URL for video input
            options: Additional options (tags, title, etc.)
            on_progress: Callback for progress updates (progress: 0-1, message: str)

        Returns:
            TaskResult with success status and output data
        """
        payload = {
            "type": config.TASK_TYPE_MEDIA,
            # MediaAction(...) also accepts the raw string value of a member.
            "action": MediaAction(action).value,
            "stream": True,
        }

        # Only send inputs the caller actually supplied; empty values are omitted.
        optional_fields = {
            "prompt": prompt,
            "mediaUrl": media_url,
            "audioUrl": audio_url,
            "videoUrl": video_url,
            "options": options,
        }
        payload.update({key: value for key, value in optional_fields.items() if value})

        return await self._post_task(payload, on_progress)

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[dict]:
        """Return the JSON object carried by an SSE ``data:`` line, or None.

        None is returned for non-data lines (blank separators, comments,
        other SSE fields), the ``[DONE]`` sentinel, invalid JSON and JSON
        values that are not objects.
        """
        if not line.startswith("data:"):
            return None
        # The SSE spec makes the space after "data:" optional; strip() also
        # drops any leftover line terminator.
        raw_data = line[5:].strip()
        if not raw_data or raw_data == "[DONE]":
            return None
        try:
            event = json.loads(raw_data)
        except json.JSONDecodeError:
            return None
        return event if isinstance(event, dict) else None

    async def _process_sse_stream(
        self,
        response: httpx.Response,
        on_progress: Optional[Callable[[float, str], None]] = None
    ) -> TaskResult:
        """Consume an SSE response and reduce its events to one TaskResult.

        Each event is a JSON object with a ``type`` and a payload that is
        either nested under ``data`` or is the event itself.

        - "progress" events are forwarded to ``on_progress``.
        - The last "result" event supplies the output (its ``output`` key
          when present).
        - Any "error" event makes the task fail.
        - "complete" and unknown event types are ignored.
        """
        final_result: Optional[Any] = None
        error_message: Optional[str] = None

        # aiter_lines() reassembles lines split across chunks (\n, \r\n or \r)
        # and also yields a final unterminated line, so every event, including
        # the last one, goes through the same handling.
        async for line in response.aiter_lines():
            event = self._parse_sse_line(line)
            if event is None:
                continue

            event_type = event.get("type", "")
            event_data = event.get("data", event)

            if event_type == "progress":
                if on_progress:
                    details = event_data if isinstance(event_data, dict) else {}
                    progress = self._calculate_progress(details)
                    message = details.get("message", "Processing...")
                    on_progress(progress, message)

            elif event_type == "result":
                if isinstance(event_data, dict):
                    final_result = event_data.get("output", event_data)
                else:
                    # Non-object payload (e.g. a bare URL string): use it as-is.
                    final_result = event_data

            elif event_type == "error":
                raw_message = (
                    event_data.get("message") if isinstance(event_data, dict) else event_data
                )
                # Never let an empty/null message mask the failure.
                error_message = str(raw_message) if raw_message else "Unknown error occurred"

        if error_message:
            return TaskResult(success=False, data={}, error=error_message)

        # Compare with None so an empty-but-present result still counts.
        if final_result is not None:
            return TaskResult(success=True, data=final_result)

        return TaskResult(
            success=False,
            data={},
            error="No result received from the API"
        )

    def _calculate_progress(self, data: dict) -> float:
        """Map a progress payload to a normalized value in [0.0, 1.0].

        "completed" -> 1.0 and "submitted" -> 0.1. "polling" interpolates
        linearly over 0.2-0.8 using ``attempt / maxAttempts`` (defaults 1
        and 30). Empty payloads, "processing", unknown statuses and
        unusable polling counters yield 0.5.
        """
        if not data:
            return 0.5

        status = data.get("status", "")

        if status == "completed":
            return 1.0
        if status == "submitted":
            return 0.1
        if status == "polling":
            try:
                ratio = float(data.get("attempt", 1)) / float(data.get("maxAttempts", 30))
            except (TypeError, ValueError, ZeroDivisionError):
                # Missing/non-numeric counters or maxAttempts == 0.
                return 0.5
            # Clamp so over-running attempt counts stay inside the polling band.
            ratio = min(max(ratio, 0.0), 1.0)
            return 0.2 + ratio * 0.6

        # "processing" and any unrecognized status.
        return 0.5

    async def download_file(self, url: str, filename: Optional[str] = None) -> str:
        """Download ``url`` into this client's temp directory; return the local path.

        Args:
            url: File URL to fetch with GET.
            filename: Local file name. Defaults to the last path segment of
                ``url`` with query string and fragment removed, or "download"
                when that is empty. Directory components are discarded so
                the file always lands inside the temp directory.

        Raises:
            httpx.HTTPStatusError: for non-2xx responses.
            httpx.RequestError: for network failures.
        """
        if not filename:
            # Strip fragment and query first, so a "/" inside them is ignored.
            filename = url.split("#")[0].split("?")[0].split("/")[-1]
        # Keep only the final component (either separator) so names such as
        # "../x" or "..\\x" cannot escape the temp directory.
        filename = filename.replace("\\", "/").split("/")[-1]
        if filename in ("", ".", ".."):
            filename = "download"

        # cleanup() may already have removed the directory; recreate on demand.
        os.makedirs(self._temp_dir, exist_ok=True)
        local_path = os.path.join(self._temp_dir, filename)

        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream("GET", url) as response:
                response.raise_for_status()
                try:
                    # Stream to disk rather than buffering the whole body in memory.
                    with open(local_path, "wb") as f:
                        async for chunk in response.aiter_bytes():
                            f.write(chunk)
                except BaseException:
                    # Includes cancellation: don't leave a truncated file behind.
                    if os.path.exists(local_path):
                        os.remove(local_path)
                    raise

        return local_path

    def cleanup(self):
        """Remove the temp directory and everything downloaded into it.

        Safe to call repeatedly; deletion errors are ignored.
        """
        import shutil
        if os.path.exists(self._temp_dir):
            shutil.rmtree(self._temp_dir, ignore_errors=True)
|
|