Skip to content

Commit 8d54b4c

Browse files
Merge pull request #15 from ryanontheinside/refactor/phase-2
Refactor/phase 2
2 parents 8b6f5f0 + ab03d72 commit 8d54b4c

File tree

14 files changed

+924
-338
lines changed

14 files changed

+924
-338
lines changed

node_wrappers/coordinates/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
Coordinate system node wrappers for ComfyUI.
3+
4+
This package provides node wrappers for the coordinate system functionality,
5+
allowing users to convert between coordinate spaces and create coordinate objects.
6+
"""
7+
8+
from .conversion_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
9+
10+
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""
2+
Node wrappers for coordinate conversion in ComfyUI.
3+
4+
These nodes provide user-friendly interfaces for the coordinate system functionality.
5+
"""
6+
7+
import torch
8+
import logging
9+
from typing import Union, List, Tuple
10+
11+
from ...src.coordinates import CoordinateSystem
12+
13+
logger = logging.getLogger(__name__)
14+
15+
class CoordinateConverterNode:
16+
"""
17+
Fast coordinate conversion between different coordinate spaces.
18+
"""
19+
CATEGORY = "Realtime Nodes/Coordinates"
20+
FUNCTION = "convert_coordinates"
21+
RETURN_TYPES = ("FLOAT", "FLOAT")
22+
RETURN_NAMES = ("x_out", "y_out")
23+
24+
@classmethod
25+
def INPUT_TYPES(cls):
26+
return {
27+
"required": {
28+
"x": ("FLOAT", {"default": 0.0, "forceInput": True, "tooltip": "X coordinate(s) to convert"}),
29+
"y": ("FLOAT", {"default": 0.0, "forceInput": True, "tooltip": "Y coordinate(s) to convert"}),
30+
"image_for_dimensions": ("IMAGE", {"tooltip": "Reference image for dimensions"}),
31+
"from_space": (["pixel", "normalized"],),
32+
"to_space": (["pixel", "normalized"],),
33+
}
34+
}
35+
36+
def convert_coordinates(self, x, y, image_for_dimensions, from_space, to_space):
37+
"""Convert coordinates using the unified coordinate system."""
38+
if image_for_dimensions.dim() != 4:
39+
logger.error("Input image must be BHWC format.")
40+
return ([0.0] * len(x), [0.0] * len(y)) if isinstance(x, list) else (0.0, 0.0)
41+
42+
try:
43+
# Get dimensions from image
44+
dimensions = CoordinateSystem.get_dimensions_from_tensor(image_for_dimensions)
45+
46+
# Map string inputs to CoordinateSystem constants
47+
space_map = {
48+
"pixel": CoordinateSystem.PIXEL,
49+
"normalized": CoordinateSystem.NORMALIZED,
50+
}
51+
52+
from_space_const = space_map[from_space]
53+
to_space_const = space_map[to_space]
54+
55+
# Convert coordinates
56+
x_out = CoordinateSystem.convert(x, dimensions[0], from_space_const, to_space_const)
57+
y_out = CoordinateSystem.convert(y, dimensions[1], from_space_const, to_space_const)
58+
59+
return (x_out, y_out)
60+
except Exception as e:
61+
logger.error(f"Error in coordinate conversion: {e}")
62+
return ([0.0] * len(x), [0.0] * len(y)) if isinstance(x, list) else (0.0, 0.0)
63+
64+
65+
class Point2DNode:
66+
"""
67+
Creates a 2D point with coordinate space awareness.
68+
"""
69+
CATEGORY = "Realtime Nodes/Coordinates"
70+
FUNCTION = "create_point"
71+
RETURN_TYPES = ("POINT",)
72+
RETURN_NAMES = ("point",)
73+
74+
@classmethod
75+
def INPUT_TYPES(cls):
76+
return {
77+
"required": {
78+
"x": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "X coordinate"}),
79+
"y": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Y coordinate"}),
80+
"space": (["normalized", "pixel"],),
81+
},
82+
"optional": {
83+
"image_for_dimensions": ("IMAGE", {"tooltip": "Reference image for dimensions (required for pixel space)"}),
84+
}
85+
}
86+
87+
def create_point(self, x, y, space, image_for_dimensions=None):
88+
"""Create a Point object."""
89+
from ...src.coordinates import Point
90+
91+
space_map = {
92+
"pixel": CoordinateSystem.PIXEL,
93+
"normalized": CoordinateSystem.NORMALIZED,
94+
}
95+
96+
space_const = space_map[space]
97+
98+
# Validate dimensions if using pixel space
99+
if space == "pixel" and image_for_dimensions is None:
100+
logger.warning("Pixel space selected but no image provided for dimensions. Using normalized space.")
101+
space_const = CoordinateSystem.NORMALIZED
102+
103+
return (Point(x, y, None, space_const),)
104+
105+
106+
class PointListNode:
107+
"""
108+
Creates a list of points from coordinate lists.
109+
"""
110+
CATEGORY = "Realtime Nodes/Coordinates"
111+
FUNCTION = "create_point_list"
112+
RETURN_TYPES = ("POINT_LIST",)
113+
RETURN_NAMES = ("point_list",)
114+
115+
@classmethod
116+
def INPUT_TYPES(cls):
117+
return {
118+
"required": {
119+
"x_coords": ("FLOAT", {"default": [0.25, 0.5, 0.75], "forceInput": True, "tooltip": "List of X coordinates"}),
120+
"y_coords": ("FLOAT", {"default": [0.25, 0.5, 0.75], "forceInput": True, "tooltip": "List of Y coordinates"}),
121+
"space": (["normalized", "pixel"],),
122+
},
123+
"optional": {
124+
"z_coords": ("FLOAT", {"default": None, "forceInput": True, "tooltip": "List of Z coordinates (optional)"}),
125+
"image_for_dimensions": ("IMAGE", {"tooltip": "Reference image for dimensions (required for pixel space)"}),
126+
}
127+
}
128+
129+
def create_point_list(self, x_coords, y_coords, space, z_coords=None, image_for_dimensions=None):
130+
"""Create a PointList object."""
131+
from ...src.coordinates import PointList
132+
133+
space_map = {
134+
"pixel": CoordinateSystem.PIXEL,
135+
"normalized": CoordinateSystem.NORMALIZED,
136+
}
137+
138+
space_const = space_map[space]
139+
140+
# Validate dimensions if using pixel space
141+
if space == "pixel" and image_for_dimensions is None:
142+
logger.warning("Pixel space selected but no image provided for dimensions. Using normalized space.")
143+
space_const = CoordinateSystem.NORMALIZED
144+
145+
# Ensure inputs are lists
146+
if not isinstance(x_coords, list):
147+
x_coords = [x_coords]
148+
if not isinstance(y_coords, list):
149+
y_coords = [y_coords]
150+
if z_coords is not None and not isinstance(z_coords, list):
151+
z_coords = [z_coords]
152+
153+
return (PointList.from_coordinates(x_coords, y_coords, z_coords, space_const),)
154+
155+
156+
# Node mapping for ComfyUI
157+
NODE_CLASS_MAPPINGS = {
158+
"CoordinateConverter": CoordinateConverterNode,
159+
"Point2D": Point2DNode,
160+
"PointList": PointListNode,
161+
}
162+
163+
# Display names for ComfyUI
164+
NODE_DISPLAY_NAME_MAPPINGS = {
165+
"CoordinateConverter": "Coordinate Converter",
166+
"Point2D": "Create 2D Point",
167+
"PointList": "Create Point List",
168+
}

node_wrappers/mediapipe_vision/location_landmark/drawing_nodes.py renamed to node_wrappers/coordinates/drawing_nodes.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import logging
33
from typing import Union, List
44

5-
from ....src.mediapipe_vision.location_landmark.drawing_engine import DrawingEngine
5+
from ...src.coordinates import CoordinateSystem
6+
from ...src.coordinates import DrawingEngine
67

78
logger = logging.getLogger(__name__)
89

@@ -35,6 +36,10 @@ def INPUT_TYPES(cls):
3536
def draw_points(self, image: torch.Tensor, x: Union[float, List[float]], y: Union[float, List[float]],
3637
radius: int, color_hex: str, is_normalized: bool = True, batch_mapping: str = "broadcast"):
3738
"""Draw points efficiently using the DrawingEngine."""
39+
# Use the coordinate system to ensure proper handling
40+
space = CoordinateSystem.NORMALIZED if is_normalized else CoordinateSystem.PIXEL
41+
dimensions = CoordinateSystem.get_dimensions_from_tensor(image)
42+
3843
drawing_engine = DrawingEngine()
3944
return (drawing_engine.draw_points(
4045
image=image,
@@ -89,6 +94,10 @@ def draw_lines(self, image: torch.Tensor,
8994
draw_label: bool = False, label_text: str = "Line",
9095
label_position: str = "Midpoint", font_scale: float = 0.5):
9196
"""Draw lines efficiently using the DrawingEngine."""
97+
# Use the coordinate system to ensure proper handling
98+
space = CoordinateSystem.NORMALIZED if is_normalized else CoordinateSystem.PIXEL
99+
dimensions = CoordinateSystem.get_dimensions_from_tensor(image)
100+
92101
drawing_engine = DrawingEngine()
93102
return (drawing_engine.draw_lines(
94103
image=image,
@@ -145,6 +154,10 @@ def draw_polygon(self, image: torch.Tensor,
145154
vertex_radius: int = 3, draw_label: bool = False,
146155
label_text: str = "Polygon", font_scale: float = 0.5):
147156
"""Draw polygon efficiently using the DrawingEngine."""
157+
# Use the coordinate system to ensure proper handling
158+
space = CoordinateSystem.NORMALIZED if is_normalized else CoordinateSystem.PIXEL
159+
dimensions = CoordinateSystem.get_dimensions_from_tensor(image)
160+
148161
drawing_engine = DrawingEngine()
149162
return (drawing_engine.draw_polygon(
150163
image=image,
@@ -186,20 +199,15 @@ def INPUT_TYPES(cls):
186199

187200
def convert_coordinates(self, x: Union[float, List[float]], y: Union[float, List[float]],
188201
image_for_dimensions: torch.Tensor, mode: str):
189-
"""Convert coordinates efficiently using DrawingEngine methods."""
190-
if image_for_dimensions.dim() != 4:
191-
logger.error("Input image must be BHWC format.")
192-
return ([0.0] * len(x), [0.0] * len(y)) if isinstance(x, list) else (0.0, 0.0)
193-
194-
_, height, width, _ = image_for_dimensions.shape
195-
drawing_engine = DrawingEngine()
202+
"""Convert coordinates using the new coordinate system."""
203+
dimensions = CoordinateSystem.get_dimensions_from_tensor(image_for_dimensions)
196204

197205
if mode == "Pixel to Normalized":
198-
x_out = drawing_engine.normalize_coords(x, width)
199-
y_out = drawing_engine.normalize_coords(y, height)
206+
x_out = CoordinateSystem.normalize(x, dimensions[0], CoordinateSystem.PIXEL)
207+
y_out = CoordinateSystem.normalize(y, dimensions[1], CoordinateSystem.PIXEL)
200208
else: # "Normalized to Pixel"
201-
x_out = drawing_engine.denormalize_coords(x, width)
202-
y_out = drawing_engine.denormalize_coords(y, height)
209+
x_out = CoordinateSystem.denormalize(x, dimensions[0], CoordinateSystem.PIXEL)
210+
y_out = CoordinateSystem.denormalize(y, dimensions[1], CoordinateSystem.PIXEL)
203211

204212
return (x_out, y_out)
205213

0 commit comments

Comments
 (0)