001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.web;
019
020import java.io.Externalizable;
021import java.io.IOException;
022import java.io.ObjectInput;
023import java.io.ObjectOutput;
024import java.util.ArrayList;
025import java.util.HashMap;
026import java.util.Iterator;
027import java.util.List;
028import java.util.Map;
029import java.util.concurrent.Semaphore;
030
031import javax.jms.Connection;
032import javax.jms.ConnectionFactory;
033import javax.jms.DeliveryMode;
034import javax.jms.Destination;
035import javax.jms.JMSException;
036import javax.jms.Message;
037import javax.jms.MessageConsumer;
038import javax.jms.MessageProducer;
039import javax.jms.Session;
040import javax.servlet.ServletContext;
041import javax.servlet.http.HttpServletRequest;
042import javax.servlet.http.HttpSession;
043import javax.servlet.http.HttpSessionActivationListener;
044import javax.servlet.http.HttpSessionBindingEvent;
045import javax.servlet.http.HttpSessionBindingListener;
046import javax.servlet.http.HttpSessionEvent;
047
048import org.apache.activemq.ActiveMQConnectionFactory;
049import org.apache.activemq.MessageAvailableConsumer;
050import org.apache.activemq.broker.BrokerRegistry;
051import org.apache.activemq.broker.BrokerService;
052import org.slf4j.Logger;
053import org.slf4j.LoggerFactory;
054
055/**
056 * Represents a messaging client used from inside a web container typically
057 * stored inside a HttpSession TODO controls to prevent DOS attacks with users
058 * requesting many consumers TODO configure consumers with small prefetch.
059 * 
060 *
061 *
062 */
063public class WebClient implements HttpSessionActivationListener, HttpSessionBindingListener, Externalizable {
064
065    public static final String WEB_CLIENT_ATTRIBUTE = "org.apache.activemq.webclient";
066    public static final String CONNECTION_FACTORY_ATTRIBUTE = "org.apache.activemq.connectionFactory";
067    public static final String CONNECTION_FACTORY_PREFETCH_PARAM = "org.apache.activemq.connectionFactory.prefetch";
068    public static final String CONNECTION_FACTORY_OPTIMIZE_ACK_PARAM = "org.apache.activemq.connectionFactory.optimizeAck";
069    public static final String BROKER_URL_INIT_PARAM = "org.apache.activemq.brokerURL";
070    public static final String USERNAME_INIT_PARAM = "org.apache.activemq.username";
071    public static final String PASSWORD_INIT_PARAM = "org.apache.activemq.password";
072    public static final String SELECTOR_NAME = "org.apache.activemq.selectorName";
073
074    private static final Logger LOG = LoggerFactory.getLogger(WebClient.class);
075
076    private static transient ActiveMQConnectionFactory factory;
077
078    private transient Map<Destination, MessageConsumer> consumers = new HashMap<Destination, MessageConsumer>();
079    private transient Connection connection;
080    private transient Session session;
081    private transient MessageProducer producer;
082    private int deliveryMode = DeliveryMode.NON_PERSISTENT;
083    public static String selectorName;
084
085    private final Semaphore semaphore = new Semaphore(1);
086
087    private String username;
088    private String password;
089
090    public WebClient() {
091        if (factory == null) {
092            throw new IllegalStateException("initContext(ServletContext) not called");
093        }
094    }
095
096    /**
097     * Helper method to get the client for the current session, lazily creating
098     * a client if there is none currently
099     * 
100     * @param request is the current HTTP request
101     * @return the current client or a newly creates
102     */
103    public static WebClient getWebClient(HttpServletRequest request) {
104        HttpSession session = request.getSession(true);
105        WebClient client = getWebClient(session);
106        if (client == null || client.isClosed()) {
107            client = WebClient.createWebClient(request);
108            session.setAttribute(WEB_CLIENT_ATTRIBUTE, client);
109        }
110
111        return client;
112    }
113
114    /**
115     * @return the web client for the current HTTP session or null if there is
116     *         not a web client created yet
117     */
118    public static WebClient getWebClient(HttpSession session) {
119        return (WebClient)session.getAttribute(WEB_CLIENT_ATTRIBUTE);
120    }
121
122    public static void initContext(ServletContext context) {
123        initConnectionFactory(context);
124        context.setAttribute("webClients", new HashMap<String, WebClient>());
125        if (selectorName == null) {
126            selectorName = context.getInitParameter(SELECTOR_NAME);
127        }
128        if (selectorName == null) {
129            selectorName = "selector";
130        }        
131    }
132
133    public int getDeliveryMode() {
134        return deliveryMode;
135    }
136
137    public void setDeliveryMode(int deliveryMode) {
138        this.deliveryMode = deliveryMode;
139    }
140
141    public String getUsername() {
142        return username;
143    }
144
145    public void setUsername(String username) {
146        this.username = username;
147    }
148
149    public String getPassword() {
150        return password;
151    }
152
153    public void setPassword(String password) {
154        this.password = password;
155    }
156
157    public synchronized void closeConsumers() {
158        for (Iterator<MessageConsumer> it = consumers.values().iterator(); it.hasNext();) {
159            MessageConsumer consumer = it.next();
160            it.remove();
161            try {
162                consumer.setMessageListener(null);
163                if (consumer instanceof MessageAvailableConsumer) {
164                    ((MessageAvailableConsumer)consumer).setAvailableListener(null);
165                }
166                consumer.close();
167            } catch (JMSException e) {
168                LOG.debug("caught exception closing consumer", e);
169            }
170        }
171    }
172
173    public synchronized void close() {
174        try {
175            if (consumers != null) {
176                closeConsumers();
177            }
178            if (connection != null) {
179                connection.close();
180            }
181        } catch (Exception e) {
182            LOG.debug("caught exception closing consumer", e);
183        } finally {
184            producer = null;
185            session = null;
186            connection = null;
187            if (consumers != null) {
188                consumers.clear();
189            }
190            consumers = null;
191
192        }
193    }
194
195    public boolean isClosed() {
196        return consumers == null;
197    }
198
199    public void writeExternal(ObjectOutput out) throws IOException {
200        if (consumers != null) {
201            out.write(consumers.size());
202            Iterator<Destination> i = consumers.keySet().iterator();
203            while (i.hasNext()) {
204                out.writeObject(i.next().toString());
205            }
206        } else {
207            out.write(-1);
208        }
209
210    }
211
212    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
213        int size = in.readInt();
214        if (size >= 0) {
215            consumers = new HashMap<Destination, MessageConsumer>();
216            for (int i = 0; i < size; i++) {
217                String destinationName = in.readObject().toString();
218
219                try {
220                    Destination destination = destinationName.startsWith("topic://") ? (Destination)getSession().createTopic(destinationName) : (Destination)getSession().createQueue(destinationName);
221                    consumers.put(destination, getConsumer(destination, null, true));
222                } catch (JMSException e) {
223                    LOG.debug("Caought Exception ", e);
224                    IOException ex = new IOException(e.getMessage());
225                    ex.initCause(e.getCause() != null ? e.getCause() : e);
226                    throw ex;
227
228                }
229            }
230        }
231    }
232
233    public void send(Destination destination, Message message) throws JMSException {
234        getProducer().send(destination, message);
235        if (LOG.isDebugEnabled()) {
236            LOG.debug("Sent! to destination: " + destination + " message: " + message);
237        }
238    }
239
240    public void send(Destination destination, Message message, boolean persistent, int priority, long timeToLive) throws JMSException {
241        int deliveryMode = persistent ? DeliveryMode.PERSISTENT : DeliveryMode.NON_PERSISTENT;
242        getProducer().send(destination, message, deliveryMode, priority, timeToLive);
243        if (LOG.isDebugEnabled()) {
244            LOG.debug("Sent! to destination: " + destination + " message: " + message);
245        }
246    }
247
248    public Session getSession() throws JMSException {
249        if (session == null) {
250            session = createSession();
251        }
252        return session;
253    }
254
255    public Connection getConnection() throws JMSException {
256        if (connection == null) {
257            if (username != null && password != null) {
258                connection = factory.createConnection(username, password);
259            } else {
260                connection = factory.createConnection();
261            }
262            connection.start();
263        }
264        return connection;
265    }
266
267    protected static synchronized void initConnectionFactory(ServletContext servletContext) {
268        if (factory == null) {
269            factory = (ActiveMQConnectionFactory)servletContext.getAttribute(CONNECTION_FACTORY_ATTRIBUTE);
270        }
271        if (factory == null) {
272            String brokerURL = getInitParameter(servletContext, BROKER_URL_INIT_PARAM);
273
274
275            if (brokerURL == null) {
276                LOG.debug("Couldn't find " + BROKER_URL_INIT_PARAM + " param, trying to find a broker embedded in a local VM");
277                BrokerService broker = BrokerRegistry.getInstance().findFirst();
278                if (broker == null) {
279                    throw new IllegalStateException("missing brokerURL (specified via " + BROKER_URL_INIT_PARAM + " init-Param) or embedded broker");
280                } else {
281                    brokerURL = "vm://" + broker.getBrokerName();
282                }
283            }
284
285            LOG.debug("Using broker URL: " + brokerURL);
286            String username = getInitParameter(servletContext, USERNAME_INIT_PARAM);
287            String password = getInitParameter(servletContext, PASSWORD_INIT_PARAM);
288            ActiveMQConnectionFactory amqfactory = new ActiveMQConnectionFactory(username, password, brokerURL);
289
290            // Set prefetch policy for factory
291            if (servletContext.getInitParameter(CONNECTION_FACTORY_PREFETCH_PARAM) != null) {
292                int prefetch = Integer.valueOf(getInitParameter(servletContext, CONNECTION_FACTORY_PREFETCH_PARAM)).intValue();
293                amqfactory.getPrefetchPolicy().setAll(prefetch);
294            }
295
296            // Set optimize acknowledge setting
297            if (servletContext.getInitParameter(CONNECTION_FACTORY_OPTIMIZE_ACK_PARAM) != null) {
298                boolean optimizeAck = Boolean.valueOf(getInitParameter(servletContext, CONNECTION_FACTORY_OPTIMIZE_ACK_PARAM)).booleanValue();
299                amqfactory.setOptimizeAcknowledge(optimizeAck);
300            }
301
302            factory = amqfactory;
303
304            servletContext.setAttribute(CONNECTION_FACTORY_ATTRIBUTE, factory);
305        }
306    }
307
308    private static String getInitParameter(ServletContext servletContext, String initParam) {
309        String result = servletContext.getInitParameter(initParam);
310        if(result != null && result.startsWith("${") && result.endsWith("}"))
311        {
312            result = System.getProperty(result.substring(2,result.length()-1));
313        }
314        return result;
315    }
316
317    public synchronized MessageProducer getProducer() throws JMSException {
318        if (producer == null) {
319            producer = getSession().createProducer(null);
320            producer.setDeliveryMode(deliveryMode);
321        }
322        return producer;
323    }
324
325    public void setProducer(MessageProducer producer) {
326        this.producer = producer;
327    }
328
329    public synchronized MessageConsumer getConsumer(Destination destination, String selector) throws JMSException {
330        return getConsumer(destination, selector, true);
331    }
332
333    public synchronized MessageConsumer getConsumer(Destination destination, String selector, boolean create) throws JMSException {
334        MessageConsumer consumer = consumers.get(destination);
335        if (create && consumer == null) {
336            consumer = getSession().createConsumer(destination, selector);
337            consumers.put(destination, consumer);
338        }
339        return consumer;
340    }
341
342    public synchronized void closeConsumer(Destination destination) throws JMSException {
343        MessageConsumer consumer = consumers.get(destination);
344        if (consumer != null) {
345            consumers.remove(destination);
346            consumer.setMessageListener(null);
347            if (consumer instanceof MessageAvailableConsumer) {
348                ((MessageAvailableConsumer)consumer).setAvailableListener(null);
349            }
350            consumer.close();
351        }
352    }
353
354    public synchronized List<MessageConsumer> getConsumers() {
355        return new ArrayList<MessageConsumer>(consumers.values());
356    }
357
358    protected Session createSession() throws JMSException {
359        return getConnection().createSession(false, Session.AUTO_ACKNOWLEDGE);
360    }
361
362    public Semaphore getSemaphore() {
363        return semaphore;
364    }
365
366    public void sessionWillPassivate(HttpSessionEvent event) {
367        close();
368    }
369
370    public void sessionDidActivate(HttpSessionEvent event) {
371    }
372
373    public void valueBound(HttpSessionBindingEvent event) {
374    }
375
376    public void valueUnbound(HttpSessionBindingEvent event) {
377        close();
378    }
379
380    protected static WebClient createWebClient(HttpServletRequest request) {
381        WebClient client = new WebClient();
382
383        String auth = request.getHeader("Authorization");
384        if (factory.getUserName() == null && factory.getPassword() == null && auth != null) {
385            String[] tokens = auth.split(" ");
386            if (tokens.length == 2) {
387                String encoded = tokens[1].trim();
388                String credentials = new String(javax.xml.bind.DatatypeConverter.parseBase64Binary(encoded));
389                String[] creds = credentials.split(":");
390                if (creds.length == 2) {
391                    client.setUsername(creds[0]);
392                    client.setPassword(creds[1]);
393                }
394            }
395        }
396        return client;
397    }
398
399}