72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
# standard imports
|
||
import re
|
||
import urllib.request
|
||
import os
|
||
import logging
|
||
|
||
# local imports
|
||
from cic_auth_helper.error import NotFoundError
|
||
|
||
logg = logging.getLogger(__name__)
|
||
|
||
|
||
# Matches WSGI-style environ keys that carry a client "X-" header
# (``HTTP_X_...``). Group 1 is the key without the ``HTTP_`` prefix.
re_x = r'^HTTP_(X_.+)$'


def add_x_headers(env, header_f):
    """Pass every ``HTTP_X_*`` entry of ``env`` to ``header_f`` as a header.

    The ``HTTP_`` prefix is dropped and the remaining underscores are
    replaced with dashes, so ``HTTP_X_FOO_BAR`` is emitted as ``X-FOO-BAR``
    (the case of the key is kept as-is).

    :param env: Mapping of environ keys to values (e.g. a WSGI environ).
    :param header_f: Callable invoked as ``header_f(name, value)`` once for
        each matching key, e.g. ``urllib.request.Request.add_header``.
    """
    for key, value in env.items():
        m = re.match(re_x, key)
        if m is None:
            # Not an X- header; leave it out of the forwarded request.
            continue
        header_f(m[1].replace('_', '-'), value)
|
||
|
||
|
||
class ReverseProxyHandler(urllib.request.BaseHandler):
    """urllib handler that turns an upstream 404 into a NotFoundError."""

    def http_error_404(self, request, response, code, msg, hdrs):
        """Raise NotFoundError for a 404 response.

        :param request: The request that produced the response (unused).
        :param response: Response object; its ``getheaders()`` result is
            attached to the raised error.
        :param code: Status code, passed through to the error.
        :param msg: Reason phrase, passed through to the error.
        :param hdrs: Headers as supplied by urllib (unused; the headers are
            read from ``response`` instead).
        :raises NotFoundError: Always.
        """
        upstream_headers = response.getheaders()
        raise NotFoundError(code, msg, upstream_headers)
|
||
|
||
|
||
class ReverseProxy:
    """Forward a WSGI request to an upstream HTTP server and relay the reply."""

    def __init__(self, base_url, ignore_proxy_headers=None):
        """Set up the proxy.

        :param base_url: URL prefix of the upstream server; the request path
            is appended to it.
        :param ignore_proxy_headers: List of response header names that must
            not be copied from the upstream response. Compared
            case-insensitively. ``None`` (the default) means no header is
            ignored.
        :raises ValueError: If ``ignore_proxy_headers`` is not a list.
        """
        self.base_url = base_url
        # None instead of a literal [] default: a default list object would be
        # shared between every call that omits the argument.
        if ignore_proxy_headers is None:
            ignore_proxy_headers = []
        if not isinstance(ignore_proxy_headers, list):
            raise ValueError('ignore_proxy_headers parameter must be a list of header keys')
        # Stored lower-cased so lookups in proxy_pass are case-insensitive.
        self.ignore_proxy_headers = [h.lower() for h in ignore_proxy_headers]
        self.opener = urllib.request.build_opener(ReverseProxyHandler())

    def proxy_pass(self, env, headers=None):
        """Send the request described by ``env`` upstream and return the reply.

        :param env: WSGI-style environ. ``REQUEST_URI`` and
            ``REQUEST_METHOD`` are required; ``CONTENT_TYPE``,
            ``wsgi.input`` and any ``HTTP_X_*`` keys are used when present.
        :param headers: Optional list of ``(name, value)`` pairs to start
            from. An upstream header replaces an existing pair with the same
            (case-insensitive) name in place; otherwise it is appended. A list
            passed by the caller is modified in place and returned; when
            omitted, a fresh list is created for each call.
        :returns: Tuple ``(status, headers, content)`` where ``status`` is
            ``'<code> <reason>'`` and ``content`` is the upstream body.
        :raises KeyError: If ``REQUEST_URI`` or ``REQUEST_METHOD`` is missing.
        """
        if headers is None:
            # A shared default list would accumulate headers across calls.
            headers = []

        # Build the URL with plain string joining: os.path.join uses the
        # platform path separator and drops base_url entirely when the
        # remaining path itself starts with '/'.
        path = env['REQUEST_URI'][1:]
        base = self.base_url
        if base and not base.endswith('/'):
            base += '/'
        url = base + path
        logg.debug('access ok -> %s', url)

        req = urllib.request.Request(url, method=env['REQUEST_METHOD'])
        add_x_headers(env, req.add_header)
        # CONTENT_TYPE may be absent or empty; fall back in both cases.
        content_type = env.get('CONTENT_TYPE') or 'application/octet-stream'
        req.add_header('Content-Type', content_type)
        # NOTE(review): the input stream is handed to urllib as-is, also for
        # body-less methods - confirm the WSGI server signals EOF on it.
        req.data = env.get('wsgi.input')

        res = self.opener.open(req)
        try:
            logg.debug('headers before reverse proxy %s', headers)

            # Lower-cased name -> index in ``headers`` for in-place replacement.
            header_keys = {pair[0].lower(): i for i, pair in enumerate(headers)}

            for h in res.getheaders():
                k = h[0].lower()
                if k in self.ignore_proxy_headers:
                    continue
                i = header_keys.get(k)
                if i is None:
                    headers.append(h)
                else:
                    headers[i] = h

            logg.debug('headers after reverse proxy %s', headers)

            status = '{} {}'.format(res.status, res.reason)
            content = res.read()
        finally:
            # Release the upstream connection even if reading fails.
            res.close()

        return (status, headers, content)
|
||
|
||
|