Subversion Repositories SmartDukaan

Rev

Rev 30 | Blame | Compare with Previous | Last modification | View Log | RSS feed

%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%%   http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%

-module(thrift_client).

-behaviour(gen_server).

%% API
-export([start_link/2, start_link/3, start_link/4,
         start/3, start/4,
         call/3, send_call/3, close/1]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
         terminate/2, code_change/3]).


-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").

-record(state, {service, protocol, seqid}).

%%====================================================================
%% API
%%====================================================================
%%--------------------------------------------------------------------
%% Function: start_link() -> {ok,Pid} | ignore | {error,Error}
%% Description: Starts the server as a linked process.
%%--------------------------------------------------------------------
start_link(Host, Port, Service) when is_integer(Port), is_atom(Service) ->
    start_link(Host, Port, Service, []).

start_link(Host, Port, Service, Options) ->
    start(Host, Port, Service, [{monitor, link} | Options]).

start_link(ProtocolFactory, Service) ->
    start(ProtocolFactory, Service, [{monitor, link}]).

%%
%% Splits client options into protocol options and transport options
%%
%% split_options([Options...]) -> {ProtocolOptions, TransportOptions}
%%
split_options(Options) ->
    split_options(Options, [], [], []).

split_options([], ClientIn, ProtoIn, TransIn) ->
    {ClientIn, ProtoIn, TransIn};

split_options([Opt = {OptKey, _} | Rest], ClientIn, ProtoIn, TransIn)
  when OptKey =:= monitor ->
    split_options(Rest, [Opt | ClientIn], ProtoIn, TransIn);

split_options([Opt = {OptKey, _} | Rest], ClientIn, ProtoIn, TransIn)
  when OptKey =:= strict_read;
       OptKey =:= strict_write ->
    split_options(Rest, ClientIn, [Opt | ProtoIn], TransIn);

split_options([Opt = {OptKey, _} | Rest], ClientIn, ProtoIn, TransIn)
  when OptKey =:= framed;
       OptKey =:= connect_timeout;
       OptKey =:= sockopts ->
    split_options(Rest, ClientIn, ProtoIn, [Opt | TransIn]).


%%--------------------------------------------------------------------
%% Function: start() -> {ok,Pid} | ignore | {error,Error}
%% Description: Starts the server as an unlinked process.
%%--------------------------------------------------------------------

%% Backwards-compatible starter for the common-case of socket transports
start(Host, Port, Service, Options)
  when is_integer(Port), is_atom(Service), is_list(Options) ->
    {ClientOpts, ProtoOpts, TransOpts} = split_options(Options),

    {ok, TransportFactory} =
        thrift_socket_transport:new_transport_factory(Host, Port, TransOpts),

    {ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory(
                              TransportFactory, ProtoOpts),

    start(ProtocolFactory, Service, ClientOpts).


%% ProtocolFactory :: fun() -> thrift_protocol()
start(ProtocolFactory, Service, ClientOpts)
  when is_function(ProtocolFactory), is_atom(Service) ->
    {Starter, Opts} =
        case lists:keysearch(monitor, 1, ClientOpts) of
            {value, {monitor, link}} ->
                {start_link, []};
            {value, {monitor, tether}} ->
                {start, [{tether, self()}]};
            _ ->
                {start, []}
        end,

    Connect =
        case lists:keysearch(connect, 1, ClientOpts) of
            {value, {connect, Choice}} ->
                Choice;
            _ ->
                %% By default, connect at creation-time.
                true
        end,


    Started = gen_server:Starter(?MODULE, [Service, Opts], []),

    if
        Connect ->
            case Started of
                {ok, Pid} ->
                    case gen_server:call(Pid, {connect, ProtocolFactory}) of
                        ok ->
                            {ok, Pid};
                        Error ->
                            Error
                    end;
                Else ->
                    Else
            end;
        true ->
            Started
    end.

call(Client, Function, Args)
  when is_pid(Client), is_atom(Function), is_list(Args) ->
    case gen_server:call(Client, {call, Function, Args}) of
        R = {ok, _} -> R;
        R = {error, _} -> R;
        {exception, Exception} -> throw(Exception)
    end.

cast(Client, Function, Args)
  when is_pid(Client), is_atom(Function), is_list(Args) ->
    gen_server:cast(Client, {call, Function, Args}).

%% Sends a function call but does not read the result. This is useful
%% if you're trying to log non-oneway function calls to write-only
%% transports like thrift_disk_log_transport.
send_call(Client, Function, Args)
  when is_pid(Client), is_atom(Function), is_list(Args) ->
    gen_server:call(Client, {send_call, Function, Args}).

close(Client) when is_pid(Client) ->
    gen_server:cast(Client, close).

%%====================================================================
%% gen_server callbacks
%%====================================================================

%%--------------------------------------------------------------------
%% Function: init(Args) -> {ok, State} |
%%                         {ok, State, Timeout} |
%%                         ignore               |
%%                         {stop, Reason}
%% Description: Initiates the server
%%--------------------------------------------------------------------
init([Service, Opts]) ->
    case lists:keysearch(tether, 1, Opts) of
        {value, {tether, Pid}} ->
            erlang:monitor(process, Pid);
        _Else ->
            ok
    end,
    {ok, #state{service = Service}}.

%%--------------------------------------------------------------------
%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} |
%%                                      {reply, Reply, State, Timeout} |
%%                                      {noreply, State} |
%%                                      {noreply, State, Timeout} |
%%                                      {stop, Reason, Reply, State} |
%%                                      {stop, Reason, State}
%% Description: Handling call messages
%%--------------------------------------------------------------------
handle_call({connect, ProtocolFactory}, _From,
            State = #state{service = Service}) ->
    case ProtocolFactory() of
        {ok, Protocol} ->
            {reply, ok, State#state{protocol = Protocol,
                                    seqid = 0}};
        Error ->
            {stop, normal, Error, State}
    end;

handle_call({call, Function, Args}, _From, State = #state{service = Service}) ->
    Result = catch_function_exceptions(
               fun() ->
                       ok = send_function_call(State, Function, Args),
                       receive_function_result(State, Function)
               end,
               Service),
    {reply, Result, State};


handle_call({send_call, Function, Args}, _From, State = #state{service = Service}) ->
    Result = catch_function_exceptions(
               fun() ->
                       send_function_call(State, Function, Args)
               end,
               Service),
    {reply, Result, State}.


%% Helper function that catches exceptions thrown by sending or receiving
%% a function and returns the correct response for call or send_only above.
catch_function_exceptions(Fun, Service) ->
    try
        Fun()
    catch
        throw:{return, Return} ->
            Return;
          error:function_clause ->
            ST = erlang:get_stacktrace(),
            case hd(ST) of
                {Service, function_info, [Function, _]} ->
                    {error, {no_function, Function}};
                _ -> throw({error, {function_clause, ST}})
            end
    end.


%%--------------------------------------------------------------------
%% Function: handle_cast(Msg, State) -> {noreply, State} |
%%                                      {noreply, State, Timeout} |
%%                                      {stop, Reason, State}
%% Description: Handling cast messages
%%--------------------------------------------------------------------
handle_cast({call, Function, Args}, State = #state{service = Service,
                                                   protocol = Protocol,
                                                   seqid = SeqId}) ->
    _Result =
        try
            ok = send_function_call(State, Function, Args),
            receive_function_result(State, Function)
        catch
            Class:Reason ->
                error_logger:error_msg("error ignored in handle_cast({cast,...},...): ~p:~p~n", [Class, Reason])
        end,

    {noreply, State};

handle_cast(close, State=#state{protocol = Protocol}) ->
%%     error_logger:info_msg("thrift_client ~p received close", [self()]),
    {stop,normal,State};
handle_cast(_Msg, State) ->
    {noreply, State}.

%%--------------------------------------------------------------------
%% Function: handle_info(Info, State) -> {noreply, State} |
%%                                       {noreply, State, Timeout} |
%%                                       {stop, Reason, State}
%% Description: Handling all non call/cast messages
%%--------------------------------------------------------------------
handle_info({'DOWN', MonitorRef, process, Pid, _Info}, State)
  when is_reference(MonitorRef), is_pid(Pid) ->
    %% We don't actually verify the correctness of the DOWN message.
    {stop, parent_died, State};

handle_info(_Info, State) ->
    {noreply, State}.

%%--------------------------------------------------------------------
%% Function: terminate(Reason, State) -> void()
%% Description: This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any necessary
%% cleaning up. When it returns, the gen_server terminates with Reason.
%% The return value is ignored.
%%--------------------------------------------------------------------
terminate(Reason, State = #state{protocol=undefined}) ->
    ok;
terminate(Reason, State = #state{protocol=Protocol}) ->
    thrift_protocol:close_transport(Protocol),
    ok.

%%--------------------------------------------------------------------
%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState}
%% Description: Convert process state when code is changed
%%--------------------------------------------------------------------
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
send_function_call(#state{protocol = Proto,
                          service  = Service,
                          seqid    = SeqId},
                   Function,
                   Args) ->
    Params = Service:function_info(Function, params_type),
    {struct, PList} = Params,
    if
        length(PList) =/= length(Args) ->
            throw({return, {error, {bad_args, Function, Args}}});
        true -> ok
    end,

    Begin = #protocol_message_begin{name = atom_to_list(Function),
                                    type = ?tMessageType_CALL,
                                    seqid = SeqId},
    ok = thrift_protocol:write(Proto, Begin),
    ok = thrift_protocol:write(Proto, {Params, list_to_tuple([Function | Args])}),
    ok = thrift_protocol:write(Proto, message_end),
    thrift_protocol:flush_transport(Proto),
    ok.

receive_function_result(State = #state{protocol = Proto,
                                       service = Service},
                        Function) ->
    ResultType = Service:function_info(Function, reply_type),
    read_result(State, Function, ResultType).

read_result(_State,
            _Function,
            oneway_void) ->
    {ok, ok};

read_result(State = #state{protocol = Proto,
                           seqid    = SeqId},
            Function,
            ReplyType) ->
    case thrift_protocol:read(Proto, message_begin) of
        #protocol_message_begin{seqid = RetSeqId} when RetSeqId =/= SeqId ->
            {error, {bad_seq_id, SeqId}};

        #protocol_message_begin{type = ?tMessageType_EXCEPTION} ->
            handle_application_exception(State);

        #protocol_message_begin{type = ?tMessageType_REPLY} ->
            handle_reply(State, Function, ReplyType)
    end.

handle_reply(State = #state{protocol = Proto,
                            service = Service},
             Function,
             ReplyType) ->
    {struct, ExceptionFields} = Service:function_info(Function, exceptions),
    ReplyStructDef = {struct, [{0, ReplyType}] ++ ExceptionFields},
    {ok, Reply} = thrift_protocol:read(Proto, ReplyStructDef),
    ReplyList = tuple_to_list(Reply),
    true = length(ReplyList) == length(ExceptionFields) + 1,
    ExceptionVals = tl(ReplyList),
    Thrown = [X || X <- ExceptionVals,
                   X =/= undefined],
    Result =
        case Thrown of
            [] when ReplyType == {struct, []} ->
                {ok, ok};
            [] ->
                {ok, hd(ReplyList)};
            [Exception] ->
                {exception, Exception}
        end,
    ok = thrift_protocol:read(Proto, message_end),
    Result.

handle_application_exception(State = #state{protocol = Proto}) ->
    {ok, Exception} = thrift_protocol:read(Proto,
                                           ?TApplicationException_Structure),
    ok = thrift_protocol:read(Proto, message_end),
    XRecord = list_to_tuple(
                ['TApplicationException' | tuple_to_list(Exception)]),
    error_logger:error_msg("X: ~p~n", [XRecord]),
    true = is_record(XRecord, 'TApplicationException'),
    {exception, XRecord}.