aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/shift.py
diff options
context:
space:
mode:
Diffstat (limited to 'shift.py')
-rw-r--r--shift.py42
1 files changed, 21 insertions, 21 deletions
diff --git a/shift.py b/shift.py
index 76ca09a..ce4353b 100644
--- a/shift.py
+++ b/shift.py
@@ -11,14 +11,14 @@ import cv2
import numpy as np
class Direction(Enum):
- N = 'N'
- NE = 'NE'
- E = 'E'
- SE = 'SE'
- S = 'S'
- SW = 'SW'
- W = 'W'
- NW = 'NW'
+ NORTH = 'N'
+ NORTH_EAST = 'NE'
+ NORTH_WEST = 'NW'
+ SOUTH = 'S'
+ SOUTH_EAST = 'SE'
+ SOUTH_WEST = 'SW'
+ EAST = 'E'
+ WEST = 'W'
def gen_kernel(self, distance):
radius = distance
@@ -26,21 +26,21 @@ class Direction(Enum):
kernel = np.zeros((size, size))
x, y = radius, radius
- if self is Direction.N:
+ if self is Direction.NORTH:
x = -1
- elif self is Direction.NE:
+ elif self is Direction.NORTH_EAST:
x, y = -1, 0
- elif self is Direction.E:
+ elif self is Direction.EAST:
y = 0
- elif self is Direction.SE:
+ elif self is Direction.SOUTH_EAST:
x, y = 0, 0
- elif self is Direction.S:
+ elif self is Direction.SOUTH:
x = 0
- elif self is Direction.SW:
+ elif self is Direction.SOUTH_WEST:
x, y = 0, -1
- elif self is Direction.W:
+ elif self is Direction.WEST:
y = -1
- elif self is Direction.NW:
+ elif self is Direction.NORTH_WEST:
x, y = -1, -1
else:
raise NotImplementedError('unsupported direction: ' + str(self))
@@ -49,7 +49,7 @@ class Direction(Enum):
return kernel
def __str__(self):
- return self.name
+ return self.value
def convolve(img, kernel):
#print(kernel)
@@ -61,11 +61,11 @@ def convolve(img, kernel):
output[i, j] = np.sum(neighborhood * kernel)
return output
-DEFAULT_DIRECTION = Direction.SE
+DEFAULT_DIRECTION = Direction.SOUTH_EAST
DEFAULT_DISTANCE = 1
-def do(img_path, direction=DEFAULT_DIRECTION, distance=DEFAULT_DISTANCE,
- output_path=None):
+def shift(img_path, direction=DEFAULT_DIRECTION, distance=DEFAULT_DISTANCE,
+ output_path=None):
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
kernel = direction.gen_kernel(distance)
@@ -106,7 +106,7 @@ def _parse_args(args=sys.argv):
return parser.parse_args(args[1:])
def main(args=sys.argv):
- do(**vars(_parse_args(args)))
+ shift(**vars(_parse_args(args)))
if __name__ == '__main__':
main()