# caffe run
#
#
# Original article link: caffe run
#!/usr/bin/env python
#coding=utf-8
import os
import sys

import cv2
import numpy as np
# caffe_root = '/home/yaochuanqi/work/tmp/ssd/'
# sys.path.insert(0, caffe_root + 'python')
import caffe


# Select CPU execution before the network object is constructed.
caffe.set_mode_cpu()

# Network definition and trained weights; relative paths, so they are
# resolved against the current working directory.
net_file = 'deploy.prototxt'
caffe_model = 'deploy.caffemodel'

# Directory whose entries are passed to detect() one by one.
test_dir = "images"

# Network instance created in TEST phase; detect() uses it as a global.
net = caffe.Net(net_file, caffe_model, caffe.TEST)

# Class names indexed by the integer class id reported by the detector
# (detect() looks names up with COCO_CLASSES[int(cls[i])]). Index 0 is
# "background"; "N/A" marks ids that have no name in this table.
COCO_CLASSES = (
    # 0-9
    "background", "person", "bicycle", "car", "motorcycle",
    "airplane", "bus", "train", "truck", "boat",
    # 10-19
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",
    "bench", "bird", "cat", "dog", "horse",
    # 20-29
    "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "N/A", "backpack", "umbrella", "N/A",
    # 30-39
    "N/A", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    # 40-49
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "N/A", "wine glass", "cup", "fork", "knife",
    # 50-59
    "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza",
    # 60-69
    "donut", "cake", "chair", "couch", "potted plant",
    "bed", "N/A", "dining table", "N/A", "N/A",
    # 70-79
    "toilet", "N/A", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 80-89
    "toaster", "sink", "refrigerator", "N/A", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier",
    # 90
    "toothbrush",
)

def preprocess(src):
    """Resize an image to 300x300 and rescale its values around zero.

    Every value v of the resized image is mapped to (v - 127.5) / 127.5,
    so an 8-bit input in [0, 255] ends up in [-1, 1].

    Args:
        src: image array accepted by cv2.resize.

    Returns:
        The resized and rescaled image as a floating-point array.
    """
    resized = cv2.resize(src, (300, 300))
    return (resized - 127.5) / 127.5

def postprocess(img, out):
    """Convert raw detector output into pixel boxes, confidences and class ids.

    Reads out['detection_out'][0, 0], a 2-D table with one row per
    detection: column 1 is the class id, column 2 the confidence and
    columns 3..6 the box coordinates. The coordinates are multiplied by
    (width, height, width, height) of ``img`` and truncated to int32.
    NOTE(review): the box columns are presumably normalized
    (xmin, ymin, xmax, ymax) -- confirm against the network definition.

    Args:
        img: image whose shape[0] / shape[1] give height / width.
        out: mapping holding the 'detection_out' array.

    Returns:
        Tuple (boxes, confidences, class_ids), one entry per detection row.
    """
    detections = out['detection_out'][0, 0]
    height, width = img.shape[0], img.shape[1]
    scale = np.array([width, height, width, height])
    boxes = (detections[:, 3:7] * scale).astype(np.int32)
    confidences = detections[:, 2]
    class_ids = detections[:, 1]
    return (boxes, confidences, class_ids)

def detect(imgfile):
    """Run the detector on one image file and display the annotated result.

    The image is read, preprocessed and pushed through the global ``net``.
    Every detection with a valid class id is drawn on the original image
    (green box plus a "class:confidence" label), and the image is shown in
    a window titled "SSD". The call blocks until a key is pressed.

    Args:
        imgfile: path of the image to process.

    Returns:
        False if ESC was pressed (the caller should stop), True otherwise.
        Files that OpenCV cannot decode are skipped and also yield True.
    """
    origimg = cv2.imread(imgfile)
    if origimg is None:
        # cv2.imread returns None for unreadable / non-image files; skip the
        # file instead of crashing inside cv2.resize.
        sys.stderr.write("skipping unreadable image: %s\n" % imgfile)
        return True

    # Feed the net a float32 array with the last (channel) axis moved first.
    blob = preprocess(origimg).astype(np.float32).transpose((2, 0, 1))
    net.blobs['data'].data[...] = blob
    out = net.forward()
    boxes, confs, classes = postprocess(origimg, out)

    for box, score, cls_id in zip(boxes, confs, classes):
        label = int(cls_id)
        # Skip rows whose class id is not a valid index: a negative id would
        # otherwise wrap around and be labelled with the last class name.
        # NOTE(review): Caffe SSD's DetectionOutput layer reportedly emits a
        # single row filled with -1 when nothing is detected -- confirm.
        if not 0 <= label < len(COCO_CLASSES):
            continue
        # Use plain Python ints; OpenCV drawing calls may reject numpy
        # integer scalars depending on the build.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(origimg, (x1, y1), (x2, y2), (0, 255, 0))
        # Keep the label at least 15 px away from the top/left border.
        anchor = (max(x1, 15), max(y1, 15))
        title = "%s:%.2f" % (COCO_CLASSES[label], score)
        cv2.putText(origimg, title, anchor, cv2.FONT_ITALIC, 0.6, (0, 255, 0), 1)
    cv2.imshow("SSD", origimg)

    # Exit if ESC (key code 27) was pressed.
    k = cv2.waitKey(0) & 0xff
    return k != 27

# Process every entry of test_dir in sorted order (os.listdir gives no
# ordering guarantee) and stop as soon as detect() reports that ESC was
# pressed.
for f in sorted(os.listdir(test_dir)):
    if not detect(os.path.join(test_dir, f)):
        break