#!/usr/bin/python

import sys
import Image
import wx
import wx.aui
import wx.lib.ogl
import operator
import os
import ImageFilter

# TODO: consider PIL's ImageSequence module (iterating frames of a multi-frame file)?

def pilImage_to_wxBitmap(pilImage, alpha=128):
    """Convert a PIL Image to a wx Bitmap with a uniform alpha channel.

    pilImage: PIL Image of any mode; it is converted to RGB, so any alpha
              it has of its own is ignored.
    alpha:    opacity (0-255) given to every pixel; the default 128 makes
              the bitmap roughly half transparent so overlaid frames show
              through each other.

    Returns a wx.Bitmap (DrawBitmap wants a bitmap, not a wx.Image).
    """
    width, height = pilImage.size
    image = wx.EmptyImage(width, height)
    # RGB is one byte per channel, which is the layout SetData expects.
    image.SetData(pilImage.convert("RGB").tostring())
    # One alpha byte per pixel, all the same value.
    # NOTE: to change the opacity, change *alpha*; wx.Image.AdjustChannels
    # returns a new image rather than modifying this one in place.
    image.SetAlphaData(chr(alpha) * (width * height))
    return wx.BitmapFromImage(image)


class TestPanel(wx.Panel):
    def __init__(self, parent, imgdir, imgs):
        self.imgdir = imgdir
        self.imgs = imgs
        self.show_imgs = [img.filter(ImageFilter.FIND_EDGES) for img in self.imgs]
        self.bitmaps = dict(((img.filename, pilImage_to_wxBitmap(show_img)) 
                             for img, show_img
                             in zip(self.imgs, self.show_imgs)))
        wx.Panel.__init__(self, parent, -1)
        self.Bind(wx.EVT_PAINT, self.OnPaint)
        self.Bind(wx.EVT_CHAR, self.OnChar)
        # populate x_offset, y_offset
        # use recorded one, else previous one, else zero.
        self.x_offset = {}
        self.y_offset = {}
        self.big_offset = 0
        self.offset_changed = {}
        last_imgname = None
        for img in self.imgs:
            xfile = img.filename + ".x"
            yfile = img.filename + ".y"
            if os.path.exists(xfile):
                self.x_offset[img.filename] = int(file(xfile).read().strip())
            if os.path.exists(yfile):
                self.y_offset[img.filename] = int(file(yfile).read().strip())
            self.offset_changed[img.filename] = False
            last_imgname = img.filename
        # need something for min() later, so if there's nothing, fake up a zero
        if not self.x_offset or not self.y_offset:
            self.x_offset[self.imgs[0].filename] = 0
            self.y_offset[self.imgs[0].filename] = 0
            self.offset_changed[self.imgs[0].img.filename] = True # at least it's never been written out

        self.forward = True
        self.cursor = 0
        self.danger_mode = False
        self.anchor_mode = False

    def save(self):
        for imgname in sorted(self.offset_changed):
            if self.offset_changed[imgname]:
                print imgname, self.x_offset[imgname], self.y_offset[imgname]
                print >> file(imgname + ".x", "w"), self.x_offset[imgname]
                print >> file(imgname + ".y", "w"), self.y_offset[imgname]
                self.offset_changed[imgname] = False

    def paste_forward(self, offset, delta):
        """apply this change if danger_mode is on"""
        if not self.danger_mode:
            return
        for cursor in range(self.cursor + 2, len(self.imgs)):
            imgname = self.imgs[cursor].filename
            offset[imgname] = offset.get(imgname, 0) + delta
            self.offset_changed[imgname] = True

    def OnChar(self, evt):
        this_imgname = self.imgs[self.cursor + 1].filename
        # initial conditions, if we didn't get here by navigation
        # and we don't have a value on disk already, use 0
        if this_imgname not in self.x_offset:
            self.x_offset[this_imgname] = 0
        if this_imgname not in self.y_offset:
            self.y_offset[this_imgname] = 0

        wxk = evt.GetKeyCode()
        if wxk == ord("q"):
            self.save()
            sys.exit()
        elif wxk == ord("s"):
            self.save()
        elif wxk == wx.WXK_UP:
            self.y_offset[this_imgname] -= 1
            self.paste_forward(self.y_offset, -1)
        elif wxk == wx.WXK_DOWN:
            self.y_offset[this_imgname] += 1
            self.paste_forward(self.y_offset, +1)
        elif wxk == wx.WXK_LEFT:
            self.x_offset[this_imgname] -= 1
            self.paste_forward(self.x_offset, -1)
        elif wxk == wx.WXK_RIGHT:
            self.x_offset[this_imgname] += 1
            self.paste_forward(self.x_offset, +1)
        elif wxk == ord("<"):
            # instead of 200, use 1/4 screen?
            self.big_offset = max(0, self.big_offset - 200)
        elif wxk == ord(">"):
            self.big_offset = self.big_offset + 200
        elif wxk == ord(" "):
            self.forward = not self.forward
        elif wxk == wx.WXK_PAGEDOWN:
            self.cursor = min(self.cursor + 1, len(self.imgs) - 2)
        elif wxk == wx.WXK_PAGEUP:
            self.cursor = max(self.cursor - 1, 0)
        elif wxk == ord("!"):
            self.PlayThroughStart()
        elif wxk == ord("^"):
            self.danger_mode = not self.danger_mode
        elif wxk == ord("*"):
            # anchor mode implies danger mode, it's a late-edit thing
            self.anchor_mode = not self.anchor_mode
            self.danger_mode = self.anchor_mode
        elif wxk == wx.WXK_TAB:
            # scan forward to next uncorrected
            for cursor in range(self.cursor + 1, len(self.imgs) - 1):
                if self.imgs[cursor + 1].filename not in self.x_offset:
                    self.cursor = cursor
                    break
        else:
            print "key:", wxk

        if wxk in [wx.WXK_UP, wx.WXK_DOWN, wx.WXK_LEFT, wx.WXK_RIGHT]:
            self.offset_changed[this_imgname] = True
        if wxk in [wx.WXK_PAGEDOWN, wx.WXK_PAGEUP, wx.WXK_TAB]:
            # if we don't have one for the next frame, copy forward this one
            # (later, copy forward a linear velocity correction...)
            next_imgname = self.imgs[self.cursor + 1].filename
            if next_imgname not in self.x_offset:
                self.x_offset[next_imgname] = self.x_offset[this_imgname]
            if next_imgname not in self.y_offset:
                self.y_offset[next_imgname] = self.y_offset[this_imgname]

        self.OnPaint(None)

    def OnPaint(self, evt):
        pdc = wx.BufferedDC(wx.PaintDC(self))

        pdc.SetBackground(wx.Brush("WHITE"))
        if self.danger_mode:
            pdc.SetBackground(wx.Brush("RED"))
        if self.anchor_mode:
            pdc.SetBackground(wx.Brush("GREEN"))

        pdc.Clear()

            
        x_shift = min(self.x_offset.values())
        y_shift = min(self.y_offset.values())
        viewed_images = self.imgs[self.cursor:self.cursor+2]
        if self.anchor_mode:
            viewed_images[0] = self.imgs[0] # or self.anchor...

        viewed_bitmaps = (self.bitmaps[img.filename] for img in viewed_images)
        x_offsets =    (self.x_offset.get(img.filename, 0) for img in viewed_images)
        y_offsets =    (self.y_offset.get(img.filename, 0) for img in viewed_images)
        work = zip(x_offsets, y_offsets, viewed_bitmaps)
        if not self.forward:
            work.reverse()
        for x_offset, y_offset, bitmap in work:
            # handle clipping...
            pdc.DrawBitmap(bitmap, 
                           x_offset - x_shift - self.big_offset, 
                           y_offset - y_shift, 
                           True)
        # label it
        # be lazy and user upper left for now
        label = "%s (%sx%s)" % (os.path.basename(viewed_images[0].filename), x_offset, y_offset)
        drawer = wx.lib.ogl.OpDraw(wx.lib.ogl.DRAWOP_DRAW_TEXT, 0, 0, 0, 0, s=label)
        drawer.Do(pdc, 0, 0)


    def PlayThroughStart(self):
        self.t1 = wx.Timer(self)
        self.t1.Start(1000/15) # 15fps
        self.play_frame = 0
        self.Bind(wx.EVT_TIMER, self.PlayThroughFrame)
        self.Unbind(wx.EVT_CHAR)
        self.Bind(wx.EVT_CHAR, self.PlayThroughKeys)

    def PlayThroughStop(self):
        self.t1.Stop()
        del self.t1
        self.Unbind(wx.EVT_TIMER)
        self.Unbind(wx.EVT_CHAR)
        self.Bind(wx.EVT_CHAR, self.OnChar)

    def PlayThroughKeys(self, evt):
        wxk = evt.GetKeyCode()
        if wxk == wx.WXK_ESCAPE:
            self.PlayThroughStop()
            return
        print "stop key:", wxk

    def PlayThroughFrame(self, evt):
        # we know it's a timer event...
        pdc = wx.BufferedDC(wx.PaintDC(self))
        pdc.SetBackground(wx.Brush("WHITE"))
        pdc.Clear()

        if self.play_frame >= len(self.imgs) - 1:
            self.PlayThroughStop()
            return

        x_shift = min(self.x_offset.values())
        y_shift = min(self.y_offset.values())
        img = self.imgs[self.play_frame]
        bitmap = self.bitmaps[img.filename]
        x_offset = self.x_offset.get(img.filename, 0)
        y_offset = self.y_offset.get(img.filename, 0)
        
        pdc.DrawBitmap(bitmap, 
                       x_offset - x_shift, 
                       y_offset - y_shift, 
                       True)

        self.play_frame += 1


class MyApp(wx.App):
    """Application that loads every .jpg in imgdir (sorted by name) into a TestPanel."""

    def __init__(self, imgdir):
        self.imgdir = imgdir
        self.paths = [os.path.join(self.imgdir, name)
                      for name in sorted(os.listdir(imgdir))
                      if name.endswith(".jpg")]  # [:150]
        self.imgs = [self._load_image(path) for path in self.paths]
        wx.App.__init__(self)

    @staticmethod
    def _load_image(path):
        """Open the image at path and decode its pixel data right away.

        Image.open() is lazy and keeps the file open until the data is
        loaded.  Loading immediately lets PIL release the file, so a long
        sequence does not hold one open descriptor per frame at once.
        """
        img = Image.open(path)
        img.load()
        return img

    def OnInit(self):
        self.frame = wx.Frame(None)
        self.tp = TestPanel(self.frame, self.imgdir, self.imgs)
        self.frame.Show()
        self.mgr = wx.aui.AuiManager()
        self.mgr.SetManagedWindow(self.frame)
        self.mgr.Update()
        # The AuiManager must be UnInit()-ed before its frame is destroyed.
        self.frame.Bind(wx.EVT_CLOSE, self._OnFrameClose)
        self.tp.SetFocus()  # otherwise we need a click to get it

        return True

    def _OnFrameClose(self, evt):
        """Detach the AuiManager, then let the default handler destroy the frame."""
        self.mgr.UnInit()
        evt.Skip()

if __name__ == "__main__":
    # Usage: <script> IMGDIR   (IMGDIR holds the .jpg frames to align)
    if len(sys.argv) < 2:
        sys.exit("usage: %s IMGDIR" % sys.argv[0])
    MyApp(sys.argv[1]).MainLoop()
